diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java index 2713f343085..4f0917fd99a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java @@ -132,7 +132,13 @@ class ByteBufStreamInput extends StreamInput { @Override public byte readByte() throws IOException { - return buffer.readByte(); + try { + return buffer.readByte(); + } catch (IndexOutOfBoundsException ex) { + EOFException eofException = new EOFException(); + eofException.initCause(ex); + throw eofException; + } } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java index 8f2908eae85..1f18049514f 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java @@ -238,7 +238,13 @@ class ByteBufUtils { @Override public byte readByte() throws IOException { - return buffer.readByte(); + try { + return buffer.readByte(); + } catch (IndexOutOfBoundsException ex) { + EOFException eofException = new EOFException(); + eofException.initCause(ex); + throw eofException; + } } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index c45525632de..1b408afde06 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -204,7 +204,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private volatile Map> requestHandlers = Collections.emptyMap(); private final ResponseHandlers responseHandlers = new ResponseHandlers(); private final TransportLogger transportLogger; - private final TcpTransportHandshaker handshaker; + private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; private final String nodeName; @@ -224,12 +224,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements this.networkService = networkService; this.transportName = transportName; this.transportLogger = new TransportLogger(); - this.handshaker = new TcpTransportHandshaker(version, threadPool, + this.handshaker = new TransportHandshaker(version, threadPool, (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, - TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, TransportRequestOptions.EMPTY, v, - TransportStatus.setHandshake((byte) 0)), + TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), + TransportRequestOptions.EMPTY, v, TransportStatus.setHandshake((byte) 0)), (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, - TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0))); + TransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0))); this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage); this.nodeName = Node.NODE_NAME_SETTING.get(settings); @@ -1287,7 +1287,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements TransportChannel transportChannel = null; try { if (TransportStatus.isHandshake(status)) { - handshaker.handleHandshake(version, features, channel, requestId); + handshaker.handleHandshake(version, features, channel, requestId, stream); } else { final RequestHandlerRegistry reg = getRequestHandler(action); if (reg == null) { diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java b/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java similarity index 65% rename from server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java rename to server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java index d1037d2bcb5..3497b29d6d0 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java @@ -21,12 +21,15 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.metrics.CounterMetric; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.threadpool.ThreadPool; +import java.io.EOFException; import java.io.IOException; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -37,7 +40,7 @@ import java.util.concurrent.atomic.AtomicBoolean; * Sends and receives transport-level connection handshakes. This class will send the initial handshake, * manage state/timeouts while the handshake is in transit, and handle the eventual response. */ -final class TcpTransportHandshaker { +final class TransportHandshaker { static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake"; private final ConcurrentMap pendingHandshakes = new ConcurrentHashMap<>(); @@ -48,8 +51,8 @@ final class TcpTransportHandshaker { private final HandshakeRequestSender handshakeRequestSender; private final HandshakeResponseSender handshakeResponseSender; - TcpTransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender, - HandshakeResponseSender handshakeResponseSender) { + TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender, + HandshakeResponseSender handshakeResponseSender) { this.version = version; this.threadPool = threadPool; this.handshakeRequestSender = handshakeRequestSender; @@ -83,11 +86,19 @@ final class TcpTransportHandshaker { } } - void handleHandshake(Version version, Set features, TcpChannel channel, long requestId) throws IOException { - handshakeResponseSender.sendResponse(version, features, channel, new VersionHandshakeResponse(this.version), requestId); + void handleHandshake(Version version, Set features, TcpChannel channel, long requestId, StreamInput stream) throws IOException { + // Must read the handshake request to exhaust the stream + HandshakeRequest handshakeRequest = new HandshakeRequest(stream); + final int nextByte = stream.read(); + if (nextByte != -1) { + throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action [" + + TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting"); + } + HandshakeResponse response = new HandshakeResponse(this.version); + handshakeResponseSender.sendResponse(version, features, channel, response, requestId); } - TransportResponseHandler removeHandlerForHandshake(long requestId) { + TransportResponseHandler removeHandlerForHandshake(long requestId) { return pendingHandshakes.remove(requestId); } @@ -99,7 +110,7 @@ final class TcpTransportHandshaker { return numHandshakes.count(); } - private class HandshakeResponseHandler implements TransportResponseHandler { + private class HandshakeResponseHandler implements TransportResponseHandler { private final long requestId; private final Version currentVersion; @@ -113,14 +124,14 @@ final class TcpTransportHandshaker { } @Override - public VersionHandshakeResponse read(StreamInput in) throws IOException { - return new VersionHandshakeResponse(in); + public HandshakeResponse read(StreamInput in) throws IOException { + return new HandshakeResponse(in); } @Override - public void handleResponse(VersionHandshakeResponse response) { + public void handleResponse(HandshakeResponse response) { if (isDone.compareAndSet(false, true)) { - Version version = response.version; + Version version = response.responseVersion; if (currentVersion.isCompatible(version) == false) { listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version + "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]")); @@ -149,24 +160,75 @@ final class TcpTransportHandshaker { } } - static final class VersionHandshakeResponse extends TransportResponse { + static final class HandshakeRequest extends TransportRequest { private final Version version; - VersionHandshakeResponse(Version version) { + HandshakeRequest(Version version) { this.version = version; } - private VersionHandshakeResponse(StreamInput in) throws IOException { + HandshakeRequest(StreamInput streamInput) throws IOException { + super(streamInput); + BytesReference remainingMessage; + try { + remainingMessage = streamInput.readBytesReference(); + } catch (EOFException e) { + remainingMessage = null; + } + if (remainingMessage == null) { + version = null; + } else { + try (StreamInput messageStreamInput = remainingMessage.streamInput()) { + this.version = Version.readVersion(messageStreamInput); + } + } + } + + @Override + public void readFrom(StreamInput in) { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + super.writeTo(streamOutput); + assert version != null; + try (BytesStreamOutput messageStreamOutput = new BytesStreamOutput(4)) { + Version.writeVersion(version, messageStreamOutput); + BytesReference reference = messageStreamOutput.bytes(); + streamOutput.writeBytesReference(reference); + } + } + } + + static final class HandshakeResponse extends TransportResponse { + + private final Version responseVersion; + + HandshakeResponse(Version responseVersion) { + this.responseVersion = responseVersion; + } + + private HandshakeResponse(StreamInput in) throws IOException { super.readFrom(in); - version = Version.readVersion(in); + responseVersion = Version.readVersion(in); + } + + @Override + public void readFrom(StreamInput in) { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - assert version != null; - Version.writeVersion(version, out); + assert responseVersion != null; + Version.writeVersion(responseVersion, out); + } + + Version getResponseVersion() { + return responseVersion; } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportStatus.java b/server/src/main/java/org/elasticsearch/transport/TransportStatus.java index 2f5f6d6bd9b..0746ed91cfb 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportStatus.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportStatus.java @@ -66,6 +66,4 @@ public final class TransportStatus { value |= STATUS_HANDSHAKE; return value; } - - } diff --git a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java index ec6860f6add..0b5e52009ce 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java @@ -21,7 +21,10 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.mockito.ArgumentCaptor; @@ -38,24 +41,24 @@ import static org.mockito.Mockito.verify; public class TransportHandshakerTests extends ESTestCase { - private TcpTransportHandshaker handshaker; + private TransportHandshaker handshaker; private DiscoveryNode node; private TcpChannel channel; private TestThreadPool threadPool; - private TcpTransportHandshaker.HandshakeRequestSender requestSender; - private TcpTransportHandshaker.HandshakeResponseSender responseSender; + private TransportHandshaker.HandshakeRequestSender requestSender; + private TransportHandshaker.HandshakeResponseSender responseSender; @Override public void setUp() throws Exception { super.setUp(); String nodeId = "node-id"; channel = mock(TcpChannel.class); - requestSender = mock(TcpTransportHandshaker.HandshakeRequestSender.class); - responseSender = mock(TcpTransportHandshaker.HandshakeResponseSender.class); + requestSender = mock(TransportHandshaker.HandshakeRequestSender.class); + responseSender = mock(TransportHandshaker.HandshakeResponseSender.class); node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(), Collections.emptySet(), Version.CURRENT); threadPool = new TestThreadPool("thread-poll"); - handshaker = new TcpTransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender); + handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender); } @Override @@ -74,20 +77,63 @@ public class TransportHandshakerTests extends ESTestCase { assertFalse(versionFuture.isDone()); TcpChannel mockChannel = mock(TcpChannel.class); - handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId); + TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + handshakeRequest.writeTo(bytesStreamOutput); + StreamInput input = bytesStreamOutput.bytes().streamInput(); + handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TransportResponse.class); verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(), eq(reqId)); - TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); - handler.handleResponse((TcpTransportHandshaker.VersionHandshakeResponse) responseCaptor.getValue()); + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue()); assertTrue(versionFuture.isDone()); assertEquals(Version.CURRENT, versionFuture.actionGet()); } + public void testHandshakeRequestFutureVersionsCompatibility() throws IOException { + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), PlainActionFuture.newFuture()); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + TcpChannel mockChannel = mock(TcpChannel.class); + TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT); + BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput(); + handshakeRequest.writeTo(currentHandshakeBytes); + + BytesStreamOutput lengthCheckingHandshake = new BytesStreamOutput(); + BytesStreamOutput futureHandshake = new BytesStreamOutput(); + TaskId.EMPTY_TASK_ID.writeTo(lengthCheckingHandshake); + TaskId.EMPTY_TASK_ID.writeTo(futureHandshake); + try (BytesStreamOutput internalMessage = new BytesStreamOutput()) { + Version.writeVersion(Version.CURRENT, internalMessage); + lengthCheckingHandshake.writeBytesReference(internalMessage.bytes()); + internalMessage.write(new byte[1024]); + futureHandshake.writeBytesReference(internalMessage.bytes()); + } + StreamInput futureHandshakeStream = futureHandshake.bytes().streamInput(); + // We check that the handshake we serialize for this test equals the actual request. + // Otherwise, we need to update the test. + assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length()); + assertEquals(1031, futureHandshakeStream.available()); + handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream); + assertEquals(0, futureHandshakeStream.available()); + + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TransportResponse.class); + verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(), + eq(reqId)); + + TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue(); + + assertEquals(Version.CURRENT, response.getResponseVersion()); + } + public void testHandshakeError() throws IOException { PlainActionFuture versionFuture = PlainActionFuture.newFuture(); long reqId = randomLongBetween(1, 10); @@ -97,7 +143,7 @@ public class TransportHandshakerTests extends ESTestCase { assertFalse(versionFuture.isDone()); - TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); handler.handleException(new TransportException("failed")); assertTrue(versionFuture.isDone()); @@ -113,7 +159,6 @@ public class TransportHandshakerTests extends ESTestCase { handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); - assertTrue(versionFuture.isDone()); ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet); assertThat(cte.getMessage(), containsString("failure to send internal:tcp/handshake")); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 6a06d0f72e1..b2e468a9b25 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2382,7 +2382,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(1, transportStats.getRxCount()); assertEquals(1, transportStats.getTxCount()); assertEquals(25, transportStats.getRxSize().getBytes()); - assertEquals(45, transportStats.getTxSize().getBytes()); + assertEquals(50, transportStats.getTxSize().getBytes()); }); serviceC.sendRequest(connection, "internal:action", new TestRequest("hello world"), TransportRequestOptions.EMPTY, transportResponseHandler); @@ -2392,7 +2392,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(1, transportStats.getRxCount()); assertEquals(2, transportStats.getTxCount()); assertEquals(25, transportStats.getRxSize().getBytes()); - assertEquals(101, transportStats.getTxSize().getBytes()); + assertEquals(106, transportStats.getTxSize().getBytes()); }); sendResponseLatch.countDown(); responseLatch.await(); @@ -2400,7 +2400,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(2, stats.getRxCount()); assertEquals(2, stats.getTxCount()); assertEquals(46, stats.getRxSize().getBytes()); - assertEquals(101, stats.getTxSize().getBytes()); + assertEquals(106, stats.getTxSize().getBytes()); } finally { serviceC.close(); } @@ -2497,7 +2497,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(1, transportStats.getRxCount()); assertEquals(1, transportStats.getTxCount()); assertEquals(25, transportStats.getRxSize().getBytes()); - assertEquals(45, transportStats.getTxSize().getBytes()); + assertEquals(50, transportStats.getTxSize().getBytes()); }); serviceC.sendRequest(connection, "internal:action", new TestRequest("hello world"), TransportRequestOptions.EMPTY, transportResponseHandler); @@ -2507,7 +2507,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(1, transportStats.getRxCount()); assertEquals(2, transportStats.getTxCount()); assertEquals(25, transportStats.getRxSize().getBytes()); - assertEquals(101, transportStats.getTxSize().getBytes()); + assertEquals(106, transportStats.getTxSize().getBytes()); }); sendResponseLatch.countDown(); responseLatch.await(); @@ -2522,7 +2522,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { // 49 bytes are the non-exception message bytes that have been received. It should include the initial // handshake message and the header, version, etc bytes in the exception message. assertEquals(failedMessage, 49 + streamOutput.bytes().length(), stats.getRxSize().getBytes()); - assertEquals(101, stats.getTxSize().getBytes()); + assertEquals(106, stats.getTxSize().getBytes()); } finally { serviceC.close(); }