Add version to handshake requests (#36171)

Currently our handshake requests do not include a version. This is
unfortunate as we cannot rely on the stream version since it is not the
sending node's version. Instead it is the minimum compatibility version.
The handshake request is currently empty and we do nothing with it. This
should allow us to add data to the request without breaking backwards
compatibility.

This commit adds the version to the handshake request. Additionally, it
allows "future data" to be added to the request. This allows nodes to craft
a version compatible response. And will properly handle additional data in
future handshake requests. The proper handling of "future data" is useful
as this is the only request where we do not know the other node's version.

Finally, it renames the TcpTransportHandshaker to
TransportHandshaker.
This commit is contained in:
Tim Brooks 2018-12-11 16:09:28 -07:00 committed by GitHub
parent 55743aac47
commit 797f985067
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 161 additions and 44 deletions

View File

@ -132,7 +132,13 @@ class ByteBufStreamInput extends StreamInput {
@Override
public byte readByte() throws IOException {
return buffer.readByte();
try {
return buffer.readByte();
} catch (IndexOutOfBoundsException ex) {
EOFException eofException = new EOFException();
eofException.initCause(ex);
throw eofException;
}
}
@Override

View File

@ -238,7 +238,13 @@ class ByteBufUtils {
@Override
public byte readByte() throws IOException {
return buffer.readByte();
try {
return buffer.readByte();
} catch (IndexOutOfBoundsException ex) {
EOFException eofException = new EOFException();
eofException.initCause(ex);
throw eofException;
}
}
@Override

View File

@ -204,7 +204,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final TransportLogger transportLogger;
private final TcpTransportHandshaker handshaker;
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
private final String nodeName;
@ -224,12 +224,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
this.networkService = networkService;
this.transportName = transportName;
this.transportLogger = new TransportLogger();
this.handshaker = new TcpTransportHandshaker(version, threadPool,
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId,
TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, TransportRequestOptions.EMPTY, v,
TransportStatus.setHandshake((byte) 0)),
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, TransportStatus.setHandshake((byte) 0)),
(v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId,
TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0)));
TransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0)));
this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage);
this.nodeName = Node.NODE_NAME_SETTING.get(settings);
@ -1287,7 +1287,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
TransportChannel transportChannel = null;
try {
if (TransportStatus.isHandshake(status)) {
handshaker.handleHandshake(version, features, channel, requestId);
handshaker.handleHandshake(version, features, channel, requestId, stream);
} else {
final RequestHandlerRegistry reg = getRequestHandler(action);
if (reg == null) {

View File

@ -21,12 +21,15 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
import java.io.EOFException;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@ -37,7 +40,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
* Sends and receives transport-level connection handshakes. This class will send the initial handshake,
* manage state/timeouts while the handshake is in transit, and handle the eventual response.
*/
final class TcpTransportHandshaker {
final class TransportHandshaker {
static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake";
private final ConcurrentMap<Long, HandshakeResponseHandler> pendingHandshakes = new ConcurrentHashMap<>();
@ -48,8 +51,8 @@ final class TcpTransportHandshaker {
private final HandshakeRequestSender handshakeRequestSender;
private final HandshakeResponseSender handshakeResponseSender;
TcpTransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
HandshakeResponseSender handshakeResponseSender) {
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
HandshakeResponseSender handshakeResponseSender) {
this.version = version;
this.threadPool = threadPool;
this.handshakeRequestSender = handshakeRequestSender;
@ -83,11 +86,19 @@ final class TcpTransportHandshaker {
}
}
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId) throws IOException {
handshakeResponseSender.sendResponse(version, features, channel, new VersionHandshakeResponse(this.version), requestId);
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
// Must read the handshake request to exhaust the stream
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
final int nextByte = stream.read();
if (nextByte != -1) {
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
}
HandshakeResponse response = new HandshakeResponse(this.version);
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
}
TransportResponseHandler<VersionHandshakeResponse> removeHandlerForHandshake(long requestId) {
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
return pendingHandshakes.remove(requestId);
}
@ -99,7 +110,7 @@ final class TcpTransportHandshaker {
return numHandshakes.count();
}
private class HandshakeResponseHandler implements TransportResponseHandler<VersionHandshakeResponse> {
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {
private final long requestId;
private final Version currentVersion;
@ -113,14 +124,14 @@ final class TcpTransportHandshaker {
}
@Override
public VersionHandshakeResponse read(StreamInput in) throws IOException {
return new VersionHandshakeResponse(in);
public HandshakeResponse read(StreamInput in) throws IOException {
return new HandshakeResponse(in);
}
@Override
public void handleResponse(VersionHandshakeResponse response) {
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
Version version = response.version;
Version version = response.responseVersion;
if (currentVersion.isCompatible(version) == false) {
listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]"));
@ -149,24 +160,75 @@ final class TcpTransportHandshaker {
}
}
static final class VersionHandshakeResponse extends TransportResponse {
static final class HandshakeRequest extends TransportRequest {
private final Version version;
VersionHandshakeResponse(Version version) {
HandshakeRequest(Version version) {
this.version = version;
}
private VersionHandshakeResponse(StreamInput in) throws IOException {
HandshakeRequest(StreamInput streamInput) throws IOException {
super(streamInput);
BytesReference remainingMessage;
try {
remainingMessage = streamInput.readBytesReference();
} catch (EOFException e) {
remainingMessage = null;
}
if (remainingMessage == null) {
version = null;
} else {
try (StreamInput messageStreamInput = remainingMessage.streamInput()) {
this.version = Version.readVersion(messageStreamInput);
}
}
}
@Override
public void readFrom(StreamInput in) {
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
}
@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
assert version != null;
try (BytesStreamOutput messageStreamOutput = new BytesStreamOutput(4)) {
Version.writeVersion(version, messageStreamOutput);
BytesReference reference = messageStreamOutput.bytes();
streamOutput.writeBytesReference(reference);
}
}
}
static final class HandshakeResponse extends TransportResponse {
private final Version responseVersion;
HandshakeResponse(Version responseVersion) {
this.responseVersion = responseVersion;
}
private HandshakeResponse(StreamInput in) throws IOException {
super.readFrom(in);
version = Version.readVersion(in);
responseVersion = Version.readVersion(in);
}
@Override
public void readFrom(StreamInput in) {
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
assert version != null;
Version.writeVersion(version, out);
assert responseVersion != null;
Version.writeVersion(responseVersion, out);
}
Version getResponseVersion() {
return responseVersion;
}
}

View File

@ -66,6 +66,4 @@ public final class TransportStatus {
value |= STATUS_HANDSHAKE;
return value;
}
}

View File

@ -21,7 +21,10 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.mockito.ArgumentCaptor;
@ -38,24 +41,24 @@ import static org.mockito.Mockito.verify;
public class TransportHandshakerTests extends ESTestCase {
private TcpTransportHandshaker handshaker;
private TransportHandshaker handshaker;
private DiscoveryNode node;
private TcpChannel channel;
private TestThreadPool threadPool;
private TcpTransportHandshaker.HandshakeRequestSender requestSender;
private TcpTransportHandshaker.HandshakeResponseSender responseSender;
private TransportHandshaker.HandshakeRequestSender requestSender;
private TransportHandshaker.HandshakeResponseSender responseSender;
@Override
public void setUp() throws Exception {
super.setUp();
String nodeId = "node-id";
channel = mock(TcpChannel.class);
requestSender = mock(TcpTransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TcpTransportHandshaker.HandshakeResponseSender.class);
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
Collections.emptySet(), Version.CURRENT);
threadPool = new TestThreadPool("thread-poll");
handshaker = new TcpTransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
}
@Override
@ -74,20 +77,63 @@ public class TransportHandshakerTests extends ESTestCase {
assertFalse(versionFuture.isDone());
TcpChannel mockChannel = mock(TcpChannel.class);
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput();
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input);
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));
TransportResponseHandler<TcpTransportHandshaker.VersionHandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TcpTransportHandshaker.VersionHandshakeResponse) responseCaptor.getValue());
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());
assertTrue(versionFuture.isDone());
assertEquals(Version.CURRENT, versionFuture.actionGet());
}
public void testHandshakeRequestFutureVersionsCompatibility() throws IOException {
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), PlainActionFuture.newFuture());
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
TcpChannel mockChannel = mock(TcpChannel.class);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
handshakeRequest.writeTo(currentHandshakeBytes);
BytesStreamOutput lengthCheckingHandshake = new BytesStreamOutput();
BytesStreamOutput futureHandshake = new BytesStreamOutput();
TaskId.EMPTY_TASK_ID.writeTo(lengthCheckingHandshake);
TaskId.EMPTY_TASK_ID.writeTo(futureHandshake);
try (BytesStreamOutput internalMessage = new BytesStreamOutput()) {
Version.writeVersion(Version.CURRENT, internalMessage);
lengthCheckingHandshake.writeBytesReference(internalMessage.bytes());
internalMessage.write(new byte[1024]);
futureHandshake.writeBytesReference(internalMessage.bytes());
}
StreamInput futureHandshakeStream = futureHandshake.bytes().streamInput();
// We check that the handshake we serialize for this test equals the actual request.
// Otherwise, we need to update the test.
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
assertEquals(1031, futureHandshakeStream.available());
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream);
assertEquals(0, futureHandshakeStream.available());
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
assertEquals(Version.CURRENT, response.getResponseVersion());
}
public void testHandshakeError() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
@ -97,7 +143,7 @@ public class TransportHandshakerTests extends ESTestCase {
assertFalse(versionFuture.isDone());
TransportResponseHandler<TcpTransportHandshaker.VersionHandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleException(new TransportException("failed"));
assertTrue(versionFuture.isDone());
@ -113,7 +159,6 @@ public class TransportHandshakerTests extends ESTestCase {
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);
assertTrue(versionFuture.isDone());
ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet);
assertThat(cte.getMessage(), containsString("failure to send internal:tcp/handshake"));

View File

@ -2382,7 +2382,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertEquals(1, transportStats.getRxCount());
assertEquals(1, transportStats.getTxCount());
assertEquals(25, transportStats.getRxSize().getBytes());
assertEquals(45, transportStats.getTxSize().getBytes());
assertEquals(50, transportStats.getTxSize().getBytes());
});
serviceC.sendRequest(connection, "internal:action", new TestRequest("hello world"), TransportRequestOptions.EMPTY,
transportResponseHandler);
@ -2392,7 +2392,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertEquals(1, transportStats.getRxCount());
assertEquals(2, transportStats.getTxCount());
assertEquals(25, transportStats.getRxSize().getBytes());
assertEquals(101, transportStats.getTxSize().getBytes());
assertEquals(106, transportStats.getTxSize().getBytes());
});
sendResponseLatch.countDown();
responseLatch.await();
@ -2400,7 +2400,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertEquals(2, stats.getRxCount());
assertEquals(2, stats.getTxCount());
assertEquals(46, stats.getRxSize().getBytes());
assertEquals(101, stats.getTxSize().getBytes());
assertEquals(106, stats.getTxSize().getBytes());
} finally {
serviceC.close();
}
@ -2497,7 +2497,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertEquals(1, transportStats.getRxCount());
assertEquals(1, transportStats.getTxCount());
assertEquals(25, transportStats.getRxSize().getBytes());
assertEquals(45, transportStats.getTxSize().getBytes());
assertEquals(50, transportStats.getTxSize().getBytes());
});
serviceC.sendRequest(connection, "internal:action", new TestRequest("hello world"), TransportRequestOptions.EMPTY,
transportResponseHandler);
@ -2507,7 +2507,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertEquals(1, transportStats.getRxCount());
assertEquals(2, transportStats.getTxCount());
assertEquals(25, transportStats.getRxSize().getBytes());
assertEquals(101, transportStats.getTxSize().getBytes());
assertEquals(106, transportStats.getTxSize().getBytes());
});
sendResponseLatch.countDown();
responseLatch.await();
@ -2522,7 +2522,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
// 49 bytes are the non-exception message bytes that have been received. It should include the initial
// handshake message and the header, version, etc bytes in the exception message.
assertEquals(failedMessage, 49 + streamOutput.bytes().length(), stats.getRxSize().getBytes());
assertEquals(101, stats.getTxSize().getBytes());
assertEquals(106, stats.getTxSize().getBytes());
} finally {
serviceC.close();
}