diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index bc24789341e..cf9791ce85a 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -111,7 +111,7 @@ public class Netty4TransportIT extends ESNetty4IntegTestCase { } @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java index 087c3758bb9..d02be2cff9e 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java @@ -113,7 +113,7 @@ public class NioTransportIT extends NioIntegTestCase { } @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java index 44e3b017ed2..953bb86c6c5 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -106,9 +106,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable features = Collections.emptySet(); } final String action = streamInput.readString(); - message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput); + message = new Request(threadContext, remoteVersion, status, requestId, action, features, streamInput); } else { - message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput); + message = new Response(threadContext, remoteVersion, status, requestId, streamInput); } success = true; return message; @@ -138,13 +138,13 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable } } - public static class RequestMessage extends InboundMessage { + public static class Request extends InboundMessage { private final String actionName; private final Set features; - RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, - StreamInput streamInput) { + Request(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, + StreamInput streamInput) { super(threadContext, version, status, requestId, streamInput); this.actionName = actionName; this.features = features; @@ -159,9 +159,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable } } - public static class ResponseMessage extends InboundMessage { + public static class Response extends InboundMessage { - ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + Response(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { super(threadContext, version, status, requestId, streamInput); } } diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index 9431258f323..4b816c6a065 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -22,8 +22,10 @@ package org.elasticsearch.transport; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NotifyOnceListener; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; @@ -32,49 +34,100 @@ import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.transport.NetworkExceptionHelper; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.Set; final class OutboundHandler { private static final Logger logger = LogManager.getLogger(OutboundHandler.class); private final MeanMetric transmittedBytesMetric = new MeanMetric(); + + private final String nodeName; + private final Version version; + private final String[] features; private final ThreadPool threadPool; private final BigArrays bigArrays; private final TransportLogger transportLogger; + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; - OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) { + OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays, + TransportLogger transportLogger) { + this.nodeName = nodeName; + this.version = version; + this.features = features; this.threadPool = threadPool; this.bigArrays = bigArrays; this.transportLogger = transportLogger; } void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); SendContext sendContext = new SendContext(channel, () -> bytes, listener); try { - internalSendMessage(channel, sendContext); + internalSend(channel, sendContext); } catch (IOException e) { // This should not happen as the bytes are already serialized throw new AssertionError(e); } } - void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); - MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); - SendContext sendContext = new SendContext(channel, serializer, listener, serializer); - internalSendMessage(channel, sendContext); + /** + * Sends the request to the given channel. This method should be used to send {@link TransportRequest} + * objects back to the caller. + */ + void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, + final TransportRequest request, final TransportRequestOptions options, final Version channelVersion, + final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException { + Version version = Version.min(this.version, channelVersion); + OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, + requestId, isHandshake, compressRequest); + ActionListener listener = ActionListener.wrap(() -> + messageListener.onRequestSent(node, requestId, action, request, options)); + sendMessage(channel, message, listener); } /** - * sends a message to the given channel, using the given callbacks. + * Sends the response to the given channel. This method should be used to send {@link TransportResponse} + * objects back to the caller. + * + * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses */ - private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException { + void sendResponse(final Version nodeVersion, final Set features, final TcpChannel channel, + final long requestId, final String action, final TransportResponse response, + final boolean compress, final boolean isHandshake) throws IOException { + Version version = Version.min(this.version, nodeVersion); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, + requestId, isHandshake, compress); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); + sendMessage(channel, message, listener); + } + + /** + * Sends back an error response to the caller via the given channel + */ + void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final long requestId, + final String action, final Exception error) throws IOException { + Version version = Version.min(this.version, nodeVersion); + TransportAddress address = new TransportAddress(channel.getLocalAddress()); + RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, + false, false); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); + sendMessage(channel, message, listener); + } + + private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { + MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); + SendContext sendContext = new SendContext(channel, serializer, listener, serializer); + internalSend(channel, sendContext); + } + + private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); BytesReference reference = sendContext.get(); try { @@ -91,6 +144,14 @@ final class OutboundHandler { return transmittedBytesMetric; } + void setMessageListener(TransportMessageListener listener) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { + messageListener = listener; + } else { + throw new IllegalStateException("Cannot set message listener twice"); + } + } + private static class MessageSerializer implements CheckedSupplier, Releasable { private final OutboundMessage message; diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index 6110752b421..8125d5bcb12 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -106,19 +106,15 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9); private static final BytesReference EMPTY_BYTES_REFERENCE = new BytesArray(new byte[0]); - private final String[] features; - protected final Settings settings; private final CircuitBreakerService circuitBreakerService; - private final Version version; protected final ThreadPool threadPool; protected final BigArrays bigArrays; protected final PageCacheRecycler pageCacheRecycler; protected final NetworkService networkService; protected final Set profileSettings; - private static final TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {}; - private volatile TransportMessageListener messageListener = NOOP_LISTENER; + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; private final ConcurrentMap profileBoundAddresses = newConcurrentMap(); private final Map> serverChannels = newConcurrentMap(); @@ -137,34 +133,23 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final TransportKeepAlive keepAlive; private final InboundMessage.Reader reader; private final OutboundHandler outboundHandler; - private final String nodeName; public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { this.settings = settings; this.profileSettings = getProfileSettings(settings); - this.version = version; this.threadPool = threadPool; this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); this.pageCacheRecycler = pageCacheRecycler; this.circuitBreakerService = circuitBreakerService; this.networkService = networkService; this.transportLogger = new TransportLogger(); - this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger); - this.handshaker = new TransportHandshaker(version, threadPool, - (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), - TransportRequestOptions.EMPTY, v, false, true), - (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); - this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); - this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); - this.nodeName = Node.NODE_NAME_SETTING.get(settings); - + String nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); + String[] features; if (defaultFeatures == null) { - this.features = new String[0]; + features = new String[0]; } else { defaultFeatures.names().forEach(key -> { if (Booleans.parseBoolean(defaultFeatures.get(key)) == false) { @@ -172,8 +157,18 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements } }); // use a sorted set to present the features in a consistent order - this.features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]); + features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]); } + this.outboundHandler = new OutboundHandler(nodeName, version, features, threadPool, bigArrays, transportLogger); + + this.handshaker = new TransportHandshaker(version, threadPool, + (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId, + TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), + TransportRequestOptions.EMPTY, v, false, true), + (v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId, + TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true)); + this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); + this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); } @Override @@ -182,8 +177,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements @Override public synchronized void setMessageListener(TransportMessageListener listener) { - if (messageListener == NOOP_LISTENER) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { messageListener = listener; + outboundHandler.setMessageListener(listener); } else { throw new IllegalStateException("Cannot set message listener twice"); } @@ -267,7 +263,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements throw new NodeNotConnectedException(node, "connection already closed"); } TcpChannel channel = channel(options.type()); - sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress); + outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false); } } @@ -661,81 +657,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements */ protected abstract void stopInternal(); - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, - final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest) throws IOException, TransportException { - sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false); - } - - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, - final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest, boolean isHandshake) throws IOException, TransportException { - Version version = Version.min(this.version, channelVersion); - OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, - requestId, isHandshake, compressRequest); - ActionListener listener = ActionListener.wrap(() -> - messageListener.onRequestSent(node, requestId, action, request, options)); - outboundHandler.sendMessage(channel, message, listener); - } - - /** - * Sends back an error response to the caller via the given channel - * - * @param nodeVersion the caller node version - * @param features the caller features - * @param channel the channel to send the response to - * @param error the error to return - * @param requestId the request ID this response replies to - * @param action the action this response replies to - */ - public void sendErrorResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final Exception error, - final long requestId, - final String action) throws IOException { - Version version = Version.min(this.version, nodeVersion); - TransportAddress address = new TransportAddress(channel.getLocalAddress()); - RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); - OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, - false, false); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); - outboundHandler.sendMessage(channel, message, listener); - } - - /** - * Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller. - * - * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller - */ - public void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - final boolean compress) throws IOException { - sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); - } - - private void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - boolean compress, - boolean isHandshake) throws IOException { - Version version = Version.min(this.version, nodeVersion); - OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, - requestId, isHandshake, compress); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - outboundHandler.sendMessage(channel, message, listener); - } - /** * Handles inbound message that has been decoded. * @@ -913,7 +834,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements message.getStoredContext().restore(); threadContext.putTransient("_remote_address", remoteAddress); if (message.isRequest()) { - handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length()); + handleRequest(channel, (InboundMessage.Request) message, reference.length()); } else { final TransportResponseHandler handler; long requestId = message.getRequestId(); @@ -999,7 +920,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements }); } - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request message, int messageLengthBytes) throws IOException { final Set features = message.getFeatures(); final String profileName = channel.getProfile(); final String action = message.getActionName(); @@ -1021,8 +942,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements } else { getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes); } - transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, profileName, - messageLengthBytes, message.isCompress()); + transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, + circuitBreakerService, messageLengthBytes, message.isCompress()); final TransportRequest request = reg.newRequest(stream); request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); // in case we throw an exception, i.e. when the limit is hit, we don't want to verify @@ -1032,8 +953,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements } catch (Exception e) { // the circuit breaker tripped if (transportChannel == null) { - transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, - profileName, 0, message.isCompress()); + transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, + circuitBreakerService, 0, message.isCompress()); } try { transportChannel.sendResponse(e); diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index b45fc19c762..aab6e25001d 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import java.io.IOException; import java.util.Set; @@ -28,38 +30,38 @@ import java.util.concurrent.atomic.AtomicBoolean; public final class TcpTransportChannel implements TransportChannel { private final AtomicBoolean released = new AtomicBoolean(); - private final TcpTransport transport; - private final Version version; - private final Set features; + private final OutboundHandler outboundHandler; + private final TcpChannel channel; private final String action; private final long requestId; - private final String profileName; + private final Version version; + private final Set features; + private final CircuitBreakerService breakerService; private final long reservedBytes; - private final TcpChannel channel; private final boolean compressResponse; - TcpTransportChannel(TcpTransport transport, TcpChannel channel, String action, long requestId, Version version, Set features, - String profileName, long reservedBytes, boolean compressResponse) { + TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version, + Set features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) { this.version = version; this.features = features; this.channel = channel; - this.transport = transport; + this.outboundHandler = outboundHandler; this.action = action; this.requestId = requestId; - this.profileName = profileName; + this.breakerService = breakerService; this.reservedBytes = reservedBytes; this.compressResponse = compressResponse; } @Override public String getProfileName() { - return profileName; + return channel.getProfile(); } @Override public void sendResponse(TransportResponse response) throws IOException { try { - transport.sendResponse(version, features, channel, response, requestId, action, compressResponse); + outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false); } finally { release(false); } @@ -68,7 +70,7 @@ public final class TcpTransportChannel implements TransportChannel { @Override public void sendResponse(Exception exception) throws IOException { try { - transport.sendErrorResponse(version, features, channel, exception, requestId, action); + outboundHandler.sendErrorResponse(version, features, channel, requestId, action, exception); } finally { release(true); } @@ -79,7 +81,7 @@ public final class TcpTransportChannel implements TransportChannel { private void release(boolean isExceptionResponse) { if (released.compareAndSet(false, true)) { assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed - transport.getInFlightRequestBreaker().addWithoutBreaking(-reservedBytes); + breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes); } else if (isExceptionResponse == false) { // only fail if we are not sending an error - we might send the error triggered by the previous // sendResponse call diff --git a/server/src/main/java/org/elasticsearch/transport/Transport.java b/server/src/main/java/org/elasticsearch/transport/Transport.java index 4357f005df5..eea8ce0f2ff 100644 --- a/server/src/main/java/org/elasticsearch/transport/Transport.java +++ b/server/src/main/java/org/elasticsearch/transport/Transport.java @@ -30,7 +30,6 @@ import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.ConcurrentMapLong; - import java.io.Closeable; import java.io.IOException; import java.net.UnknownHostException; diff --git a/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java b/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java index bc57c62ca8d..62ff3d8fa43 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java @@ -22,6 +22,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode; public interface TransportMessageListener { + TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {}; + /** * Called once a request is received * @param requestId the internal request ID diff --git a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java index 499b6586543..2615a3fdc35 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java @@ -63,7 +63,7 @@ public class InboundMessageTests extends ESTestCase { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced); + InboundMessage.Request inboundMessage = (InboundMessage.Request) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); @@ -102,7 +102,7 @@ public class InboundMessageTests extends ESTestCase { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); @@ -138,7 +138,7 @@ public class InboundMessageTests extends ESTestCase { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index 01e391a30a7..baab504e61f 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -19,14 +19,16 @@ package org.elasticsearch.transport; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESTestCase; @@ -38,24 +40,34 @@ import org.junit.Before; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collections; -import java.util.HashSet; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.instanceOf; + public class OutboundHandlerTests extends ESTestCase { + private final String feature1 = "feature1"; + private final String feature2 = "feature2"; private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private OutboundHandler handler; - private FakeTcpChannel fakeTcpChannel; + private FakeTcpChannel channel; + private DiscoveryNode node; @Before public void setUp() throws Exception { super.setUp(); TransportLogger transportLogger = new TransportLogger(); - fakeTcpChannel = new FakeTcpChannel(randomBoolean()); - handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); + channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); + TransportAddress transportAddress = buildNewFakeTransportAddress(); + node = new DiscoveryNode("", transportAddress, Version.CURRENT); + String[] features = {feature1, feature2}; + handler = new OutboundHandler("node", Version.CURRENT, features, threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); } @After @@ -70,10 +82,10 @@ public class OutboundHandlerTests extends ESTestCase { AtomicBoolean isSuccess = new AtomicBoolean(false); AtomicReference exception = new AtomicReference<>(); ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); - handler.sendBytes(fakeTcpChannel, bytesArray, listener); + handler.sendBytes(channel, bytesArray, listener); - BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); - ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { sendListener.onResponse(null); assertTrue(isSuccess.get()); @@ -88,55 +100,51 @@ public class OutboundHandlerTests extends ESTestCase { assertEquals(bytesArray, reference); } - public void testSendMessage() throws IOException { - OutboundMessage message; + public void testSendRequest() throws IOException { ThreadContext threadContext = threadPool.getThreadContext(); - Version version = Version.CURRENT; - String actionName = "handshake"; + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; long requestId = randomLongBetween(0, 300); boolean isHandshake = randomBoolean(); boolean compress = randomBoolean(); String value = "message"; threadContext.putHeader("header", "header_value"); - Writeable writeable = new Message(value); + Request request = new Request(value); - boolean isRequest = randomBoolean(); - if (isRequest) { - message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake, - compress); - } else { - message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress); - } + AtomicReference nodeRef = new AtomicReference<>(); + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference requestRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request, + TransportRequestOptions options) { + nodeRef.set(node); + requestIdRef.set(requestId); + actionRef.set(action); + requestRef.set(request); + } + }); + handler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake); - AtomicBoolean isSuccess = new AtomicBoolean(false); - AtomicReference exception = new AtomicReference<>(); - ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); - handler.sendMessage(fakeTcpChannel, message, listener); - - BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); - ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { sendListener.onResponse(null); - assertTrue(isSuccess.get()); - assertNull(exception.get()); } else { - IOException e = new IOException("failed"); - sendListener.onFailure(e); - assertFalse(isSuccess.get()); - assertSame(e, exception.get()); + sendListener.onFailure(new IOException("failed")); } + assertEquals(node, nodeRef.get()); + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(request, requestRef.get()); InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { assertEquals(version, inboundMessage.getVersion()); assertEquals(requestId, inboundMessage.getRequestId()); - if (isRequest) { - assertTrue(inboundMessage.isRequest()); - assertFalse(inboundMessage.isResponse()); - } else { - assertTrue(inboundMessage.isResponse()); - assertFalse(inboundMessage.isRequest()); - } + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); if (isHandshake) { assertTrue(inboundMessage.isHandshake()); } else { @@ -147,7 +155,10 @@ public class OutboundHandlerTests extends ESTestCase { } else { assertFalse(inboundMessage.isCompress()); } - Message readMessage = new Message(); + InboundMessage.Request inboundRequest = (InboundMessage.Request) inboundMessage; + assertThat(inboundRequest.getFeatures(), contains(feature1, feature2)); + + Request readMessage = new Request(); readMessage.readFrom(inboundMessage.getStreamInput()); assertEquals(value, readMessage.value); @@ -160,14 +171,163 @@ public class OutboundHandlerTests extends ESTestCase { } } - private static final class Message extends TransportMessage { + public void testSendResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + Response response = new Response(value); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(response); + } + }); + handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(response, responseRef.get()); + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isResponse()); + if (isHandshake) { + assertTrue(inboundMessage.isHandshake()); + } else { + assertFalse(inboundMessage.isHandshake()); + } + if (compress) { + assertTrue(inboundMessage.isCompress()); + } else { + assertFalse(inboundMessage.isCompress()); + } + + InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage; + assertFalse(inboundResponse.isError()); + + Response readMessage = new Response(); + readMessage.readFrom(inboundMessage.getStreamInput()); + assertEquals(value, readMessage.value); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + public void testErrorResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + threadContext.putHeader("header", "header_value"); + ElasticsearchException error = new ElasticsearchException("boom"); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(error); + } + }); + handler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(error, responseRef.get()); + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isCompress()); + assertFalse(inboundMessage.isHandshake()); + + InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage; + assertTrue(inboundResponse.isError()); + + RemoteTransportException remoteException = inboundMessage.getStreamInput().readException(); + assertThat(remoteException.getCause(), instanceOf(ElasticsearchException.class)); + assertEquals(remoteException.getCause().getMessage(), "boom"); + assertEquals(action, remoteException.action()); + assertEquals(channel.getLocalAddress(), remoteException.address().address()); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + private static final class Request extends TransportRequest { public String value; - private Message() { + private Request() { } - private Message(String value) { + private Request(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } + + private static final class Response extends TransportResponse { + + public String value; + + private Response() { + } + + private Response(String value) { this.value = value; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 5c98d1b9dfb..cb03742c297 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2008,12 +2008,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { // we flip the isHandshake bit back and act like the handler is not found byte status = (byte) (request.status & ~(1 << 3)); Version version = request.getVersion(); - InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version, + InboundMessage.Request nonHandshakeRequest = new InboundMessage.Request(request.threadContext, version, status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput()); super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java index bb392554305..e9593fc6622 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java @@ -44,6 +44,10 @@ public class FakeTcpChannel implements TcpChannel { this(isServer, "profile", new AtomicReference<>()); } + public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress) { + this(isServer, localAddress, remoteAddress, "profile", new AtomicReference<>()); + } + public FakeTcpChannel(boolean isServer, AtomicReference messageCaptor) { this(isServer, "profile", messageCaptor); }