diff --git a/server/src/main/java/org/elasticsearch/transport/Header.java b/server/src/main/java/org/elasticsearch/transport/Header.java index 86cbd044114..41fa2afb839 100644 --- a/server/src/main/java/org/elasticsearch/transport/Header.java +++ b/server/src/main/java/org/elasticsearch/transport/Header.java @@ -42,7 +42,7 @@ public class Header { // These are directly set by tests String actionName; Tuple, Map>> headers; - private Set features; + Set features; Header(int networkMessageSize, long requestId, byte status, Version version) { this.networkMessageSize = networkMessageSize; diff --git a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java index ef35251ffe8..02523185220 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java @@ -72,7 +72,7 @@ public class InboundDecoder implements Releasable { } else { totalNetworkSize = messageLength + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE; - Header header = readHeader(messageLength, reference); + Header header = readHeader(version, messageLength, reference); bytesConsumed += headerBytesToRead; if (header.isCompressed()) { decompressor = new TransportDecompressor(recycler); @@ -166,7 +166,8 @@ public class InboundDecoder implements Releasable { } } - private Header readHeader(int networkMessageSize, BytesReference bytesReference) throws IOException { + // exposed for use in tests + static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException { try (StreamInput streamInput = bytesReference.streamInput()) { streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE); long requestId = streamInput.readLong(); diff --git a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java index 4943e7d02c2..a071a26e391 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java @@ -143,7 +143,14 @@ public class InboundHandler { try { handshaker.handleHandshake(transportChannel, requestId, stream); } catch (Exception e) { - sendErrorResponse(action, transportChannel, e); + if (Version.CURRENT.isCompatible(header.getVersion())) { + sendErrorResponse(action, transportChannel, e); + } else { + logger.warn(new ParameterizedMessage( + "could not send error response to handshake received on [{}] using wire format version [{}], closing channel", + channel, header.getVersion()), e); + channel.close(); + } } } else { final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, diff --git a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java index 5fd70a9d13b..ee99c7bb0d0 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java @@ -19,25 +19,37 @@ package org.elasticsearch.transport; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.InputStreamStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.MockLogAppender; +import org.elasticsearch.test.VersionUtils; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.io.InputStream; import java.util.Collections; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.instanceOf; @@ -165,4 +177,84 @@ public class InboundHandlerTests extends ESTestCase { assertEquals(responseValue, responseCaptor.get().value); } } + + public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exception { + // Nodes use their minimum compatibility version for the TCP handshake, so a node from v(major-1).x will report its version as + // v(major-2).last in the TCP handshake, with which we are not really compatible. We put extra effort into making sure that if + // successful we can respond correctly in a format this old, but we do not guarantee that we can respond correctly with an error + // response. However if the two nodes are from the same major version then we do guarantee compatibility of error responses. + + final Version remoteVersion = VersionUtils.randomCompatibleVersion(random(), version); + final long requestId = randomNonNegativeLong(); + final Header requestHeader = new Header(between(0, 100), requestId, + TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; + requestHeader.headers = Tuple.tuple(org.elasticsearch.common.collect.Map.of(), org.elasticsearch.common.collect.Map.of()); + requestHeader.features = org.elasticsearch.common.collect.Set.of(); + handler.inboundMessage(channel, requestMessage); + + final BytesReference responseBytesReference = channel.getMessageCaptor().get(); + final Header responseHeader = InboundDecoder.readHeader(remoteVersion, responseBytesReference.length(), responseBytesReference); + assertTrue(responseHeader.isResponse()); + assertTrue(responseHeader.isError()); + } + + + public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws Exception { + // Nodes use their minimum compatibility version for the TCP handshake, so a node from v(major-1).x will report its version as + // v(major-2).last in the TCP handshake, with which we are not really compatible. We put extra effort into making sure that if + // successful we can respond correctly in a format this old, but we do not guarantee that we can respond correctly with an error + // response so we must just close the connection on an error. To avoid the failure disappearing into a black hole we at least log + // it. + + final MockLogAppender mockAppender = new MockLogAppender(); + mockAppender.start(); + mockAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "expected message", + InboundHandler.class.getCanonicalName(), + Level.WARN, + "could not send error response to handshake")); + final Logger inboundHandlerLogger = LogManager.getLogger(InboundHandler.class); + Loggers.addAppender(inboundHandlerLogger, mockAppender); + + try { + final AtomicBoolean isClosed = new AtomicBoolean(); + channel.addCloseListener(ActionListener.wrap(() -> assertTrue(isClosed.compareAndSet(false, true)))); + + final Version remoteVersion = Version.fromId(randomIntBetween(0, version.minimumCompatibilityVersion().id - 1)); + final long requestId = randomNonNegativeLong(); + final Header requestHeader = new Header(between(0, 100), requestId, + TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; + requestHeader.headers = Tuple.tuple(org.elasticsearch.common.collect.Map.of(), org.elasticsearch.common.collect.Map.of()); + requestHeader.features = org.elasticsearch.common.collect.Set.of(); + handler.inboundMessage(channel, requestMessage); + assertTrue(isClosed.get()); + assertNull(channel.getMessageCaptor().get()); + mockAppender.assertAllExpectationsMatched(); + } finally { + Loggers.removeAppender(inboundHandlerLogger, mockAppender); + mockAppender.stop(); + } + } + + private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { + return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> { }) { + @Override + public StreamInput openOrGetStreamInput() { + final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { + @Override + public int read() { + throw new ElasticsearchException("unreadable handshake"); + } + }); + streamInput.setVersion(remoteVersion); + return streamInput; + } + }; + } + }