Use TransportChannel in TransportHandshaker (#54921)

Currently the TransportHandshaker has a specialized codepath for sending
a response. In other work, we are going to start having handshakes
contribute to circuit breaking (while not being breakable). This commit
moves in that direction by allowing the handshaker to responding using a
standard TcpTransportChannel similar to other requests.
This commit is contained in:
Tim Brooks 2020-04-07 15:37:15 -06:00 committed by GitHub
parent ce7ae4a7d1
commit c7053ef824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 109 additions and 150 deletions

View File

@ -157,7 +157,10 @@ public class InboundHandler {
try { try {
messageListener.onRequestReceived(requestId, action); messageListener.onRequestReceived(requestId, action);
if (header.isHandshake()) { if (header.isHandshake()) {
handshaker.handleHandshake(version, features, channel, requestId, stream); // Handshakes are not currently circuit broken
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
handshaker.handleHandshake(transportChannel, requestId, stream);
} else { } else {
final RequestHandlerRegistry<T> reg = getRequestHandler(action); final RequestHandlerRegistry<T> reg = getRequestHandler(action);
if (reg == null) { if (reg == null) {
@ -170,7 +173,7 @@ public class InboundHandler {
breaker.addWithoutBreaking(messageLengthBytes); breaker.addWithoutBreaking(messageLengthBytes);
} }
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
circuitBreakerService, messageLengthBytes, header.isCompressed()); circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
final T request = reg.newRequest(stream); final T request = reg.newRequest(stream);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@ -186,7 +189,7 @@ public class InboundHandler {
// the circuit breaker tripped // the circuit breaker tripped
if (transportChannel == null) { if (transportChannel == null) {
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
circuitBreakerService, 0, header.isCompressed()); circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
} }
try { try {
transportChannel.sendResponse(e); transportChannel.sendResponse(e);

View File

@ -159,9 +159,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
this.handshaker = new TransportHandshaker(version, threadPool, this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId, (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true), TransportRequestOptions.EMPTY, v, false, true));
(v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker, this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
keepAlive); keepAlive);

View File

@ -39,9 +39,11 @@ public final class TcpTransportChannel implements TransportChannel {
private final CircuitBreakerService breakerService; private final CircuitBreakerService breakerService;
private final long reservedBytes; private final long reservedBytes;
private final boolean compressResponse; private final boolean compressResponse;
private final boolean isHandshake;
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version, TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) { Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse,
boolean isHandshake) {
this.version = version; this.version = version;
this.features = features; this.features = features;
this.channel = channel; this.channel = channel;
@ -51,6 +53,7 @@ public final class TcpTransportChannel implements TransportChannel {
this.breakerService = breakerService; this.breakerService = breakerService;
this.reservedBytes = reservedBytes; this.reservedBytes = reservedBytes;
this.compressResponse = compressResponse; this.compressResponse = compressResponse;
this.isHandshake = isHandshake;
} }
@Override @Override
@ -61,7 +64,7 @@ public final class TcpTransportChannel implements TransportChannel {
@Override @Override
public void sendResponse(TransportResponse response) throws IOException { public void sendResponse(TransportResponse response) throws IOException {
try { try {
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false); outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, isHandshake);
} finally { } finally {
release(false); release(false);
} }
@ -102,6 +105,5 @@ public final class TcpTransportChannel implements TransportChannel {
public TcpChannel getChannel() { public TcpChannel getChannel() {
return channel; return channel;
} }
} }

View File

@ -31,7 +31,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.EOFException; import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -49,14 +48,11 @@ final class TransportHandshaker {
private final Version version; private final Version version;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final HandshakeRequestSender handshakeRequestSender; private final HandshakeRequestSender handshakeRequestSender;
private final HandshakeResponseSender handshakeResponseSender;
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender, TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
HandshakeResponseSender handshakeResponseSender) {
this.version = version; this.version = version;
this.threadPool = threadPool; this.threadPool = threadPool;
this.handshakeRequestSender = handshakeRequestSender; this.handshakeRequestSender = handshakeRequestSender;
this.handshakeResponseSender = handshakeResponseSender;
} }
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) { void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
@ -88,7 +84,7 @@ final class TransportHandshaker {
} }
} }
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException { void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
// Must read the handshake request to exhaust the stream // Must read the handshake request to exhaust the stream
HandshakeRequest handshakeRequest = new HandshakeRequest(stream); HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
final int nextByte = stream.read(); final int nextByte = stream.read();
@ -96,8 +92,7 @@ final class TransportHandshaker {
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action [" throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting"); + TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
} }
HandshakeResponse response = new HandshakeResponse(this.version); channel.sendResponse(new HandshakeResponse(this.version));
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
} }
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) { TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
@ -228,11 +223,4 @@ final class TransportHandshaker {
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException; void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
} }
@FunctionalInterface
interface HandshakeResponseSender {
void sendResponse(Version version, Set<String> features, TcpChannel channel, TransportResponse response, long requestId)
throws IOException;
}
} }

View File

@ -57,7 +57,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TestTransportChannel;
import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.junit.After; import org.junit.After;
@ -366,14 +366,15 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler = final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
action.new BroadcastByNodeTransportRequestHandler(); action.new BroadcastByNodeTransportRequestHandler();
TestTransportChannel channel = new TestTransportChannel(); final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
TestTransportChannel channel = new TestTransportChannel(future);
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null); handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
// check the operation was executed only on the expected shards // check the operation was executed only on the expected shards
assertEquals(shards, action.getResults().keySet()); assertEquals(shards, action.getResults().keySet());
TransportResponse response = channel.getCapturedResponse(); TransportResponse response = future.actionGet();
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse); assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response; TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
@ -469,32 +470,4 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
assertEquals("failed shards", totalFailedShards, response.getFailedShards()); assertEquals("failed shards", totalFailedShards, response.getFailedShards());
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length); assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
} }
public class TestTransportChannel implements TransportChannel {
private TransportResponse capturedResponse;
public TransportResponse getCapturedResponse() {
return capturedResponse;
}
@Override
public String getProfileName() {
return "";
}
@Override
public void sendResponse(TransportResponse response) throws IOException {
capturedResponse = response;
}
@Override
public void sendResponse(Exception exception) throws IOException {
}
@Override
public String getChannelType() {
return "test";
}
}
} }

View File

@ -78,6 +78,7 @@ import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TestTransportChannel;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
@ -817,7 +818,7 @@ public class TransportReplicationActionTests extends ESTestCase {
Request request = new Request(shardId); Request request = new Request(shardId);
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest = TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm); new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
final IndexShard shard = mockIndexShard(shardId, clusterService); final IndexShard shard = mockIndexShard(shardId, clusterService);
@ -981,7 +982,7 @@ public class TransportReplicationActionTests extends ESTestCase {
setState(clusterService, state(index, true, ShardRoutingState.STARTED)); setState(clusterService, state(index, true, ShardRoutingState.STARTED));
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard(); final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id()); final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
final boolean wrongAllocationId = randomBoolean(); final boolean wrongAllocationId = randomBoolean();
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10); final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
Request request = new Request(shardId).timeout("1ms"); Request request = new Request(shardId).timeout("1ms");
@ -1018,7 +1019,7 @@ public class TransportReplicationActionTests extends ESTestCase {
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build(); state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
setState(clusterService, state); setState(clusterService, state);
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
Request request = new Request(shardId).timeout("1ms"); Request request = new Request(shardId).timeout("1ms");
action.handleReplicaRequest( action.handleReplicaRequest(
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(), new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
@ -1062,7 +1063,7 @@ public class TransportReplicationActionTests extends ESTestCase {
return new ReplicaResult(); return new ReplicaResult();
} }
}; };
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
final Request request = new Request(shardId); final Request request = new Request(shardId);
final long checkpoint = randomNonNegativeLong(); final long checkpoint = randomNonNegativeLong();
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong(); final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
@ -1130,7 +1131,7 @@ public class TransportReplicationActionTests extends ESTestCase {
return new ReplicaResult(); return new ReplicaResult();
} }
}; };
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
final Request request = new Request(shardId); final Request request = new Request(shardId);
final long checkpoint = randomNonNegativeLong(); final long checkpoint = randomNonNegativeLong();
final long maxSeqNoOfUpdates = randomNonNegativeLong(); final long maxSeqNoOfUpdates = randomNonNegativeLong();
@ -1371,29 +1372,8 @@ public class TransportReplicationActionTests extends ESTestCase {
/** /**
* Transport channel that is needed for replica operation testing. * Transport channel that is needed for replica operation testing.
*/ */
public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) { public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
return new TransportChannel() { return new TestTransportChannel(listener);
@Override
public String getProfileName() {
return "";
}
@Override
public void sendResponse(TransportResponse response) {
listener.onResponse(((TestResponse) response));
}
@Override
public void sendResponse(Exception exception) {
listener.onFailure(exception);
}
@Override
public String getChannelType() {
return "replica_test";
}
};
} }
} }

View File

@ -20,6 +20,7 @@ package org.elasticsearch.cluster.coordination;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ESAllocationTestCase; import org.elasticsearch.cluster.ESAllocationTestCase;
@ -44,8 +45,8 @@ import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RequestHandlerRegistry; import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.TestTransportChannel;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
@ -229,29 +230,22 @@ public class NodeJoinTests extends ESTestCase {
try { try {
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>) final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME); transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
joinHandler.processMessageReceived(joinRequest, new TransportChannel() { final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() {
@Override
public String getProfileName() {
return "dummy";
}
@Override @Override
public String getChannelType() { public void onResponse(TransportResponse transportResponse) {
return "dummy";
}
@Override
public void sendResponse(TransportResponse response) {
logger.debug("{} completed", future); logger.debug("{} completed", future);
future.markAsDone(); future.markAsDone();
} }
@Override @Override
public void sendResponse(Exception e) { public void onFailure(Exception e) {
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e); logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
future.markAsFailed(e); future.markAsFailed(e);
} }
}); };
joinHandler.processMessageReceived(joinRequest, new TestTransportChannel(listener));
} catch (Exception e) { } catch (Exception e) {
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e); logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
future.markAsFailed(e); future.markAsFailed(e);
@ -402,27 +396,17 @@ public class NodeJoinTests extends ESTestCase {
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception { private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>) final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME); transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TransportChannel() { startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(
@Override new ActionListener<TransportResponse>() {
public String getProfileName() { @Override
return "dummy"; public void onResponse(TransportResponse transportResponse) {
} }
@Override @Override
public String getChannelType() { public void onFailure(Exception e) {
return "dummy";
}
@Override
public void sendResponse(TransportResponse response) {
}
@Override
public void sendResponse(Exception exception) {
fail(); fail();
} }
}); }));
deterministicTaskQueue.runAllRunnableTasks(); deterministicTaskQueue.runAllRunnableTasks();
assertFalse(isLocalNodeElectedMaster()); assertFalse(isLocalNodeElectedMaster());
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE)); assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
@ -432,27 +416,19 @@ public class NodeJoinTests extends ESTestCase {
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler = final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>) (RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME); transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), new TransportChannel() { final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() {
@Override @Override
public String getProfileName() { public void onResponse(TransportResponse transportResponse) {
return "dummy";
}
@Override
public String getChannelType() {
return "dummy";
}
@Override
public void sendResponse(TransportResponse response) {
} }
@Override @Override
public void sendResponse(Exception exception) { public void onFailure(Exception e) {
fail(); fail();
} }
}); });
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), channel);
// Will throw exception if failed
deterministicTaskQueue.runAllRunnableTasks(); deterministicTaskQueue.runAllRunnableTasks();
assertFalse(isLocalNodeElectedMaster()); assertFalse(isLocalNodeElectedMaster());
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER)); assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));

View File

@ -58,7 +58,7 @@ public class InboundHandlerTests extends ESTestCase {
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}, (v, f, c, r, r_id) -> {}); TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage); TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool, OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
BigArrays.NON_RECYCLING_INSTANCE); BigArrays.NON_RECYCLING_INSTANCE);

View File

@ -27,14 +27,12 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.mockito.ArgumentCaptor;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -46,7 +44,6 @@ public class TransportHandshakerTests extends ESTestCase {
private TcpChannel channel; private TcpChannel channel;
private TestThreadPool threadPool; private TestThreadPool threadPool;
private TransportHandshaker.HandshakeRequestSender requestSender; private TransportHandshaker.HandshakeRequestSender requestSender;
private TransportHandshaker.HandshakeResponseSender responseSender;
@Override @Override
public void setUp() throws Exception { public void setUp() throws Exception {
@ -54,11 +51,10 @@ public class TransportHandshakerTests extends ESTestCase {
String nodeId = "node-id"; String nodeId = "node-id";
channel = mock(TcpChannel.class); channel = mock(TcpChannel.class);
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class); requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(), node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
Collections.emptySet(), Version.CURRENT); Collections.emptySet(), Version.CURRENT);
threadPool = new TestThreadPool("thread-poll"); threadPool = new TestThreadPool("thread-poll");
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender); handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender);
} }
@Override @Override
@ -76,20 +72,16 @@ public class TransportHandshakerTests extends ESTestCase {
assertFalse(versionFuture.isDone()); assertFalse(versionFuture.isDone());
TcpChannel mockChannel = mock(TcpChannel.class);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT); TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
handshakeRequest.writeTo(bytesStreamOutput); handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput(); StreamInput input = bytesStreamOutput.bytes().streamInput();
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input); final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
handshaker.handleHandshake(channel, reqId, input);
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId); TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue()); handler.handleResponse((TransportHandshaker.HandshakeResponse) responseFuture.actionGet());
assertTrue(versionFuture.isDone()); assertTrue(versionFuture.isDone());
assertEquals(Version.CURRENT, versionFuture.actionGet()); assertEquals(Version.CURRENT, versionFuture.actionGet());
@ -101,7 +93,6 @@ public class TransportHandshakerTests extends ESTestCase {
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
TcpChannel mockChannel = mock(TcpChannel.class);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT); TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput(); BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
handshakeRequest.writeTo(currentHandshakeBytes); handshakeRequest.writeTo(currentHandshakeBytes);
@ -121,15 +112,12 @@ public class TransportHandshakerTests extends ESTestCase {
// Otherwise, we need to update the test. // Otherwise, we need to update the test.
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length()); assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
assertEquals(1031, futureHandshakeStream.available()); assertEquals(1031, futureHandshakeStream.available());
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream); final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
handshaker.handleHandshake(channel, reqId, futureHandshakeStream);
assertEquals(0, futureHandshakeStream.available()); assertEquals(0, futureHandshakeStream.available());
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseFuture.actionGet();
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
assertEquals(Version.CURRENT, response.getResponseVersion()); assertEquals(Version.CURRENT, response.getResponseVersion());
} }

View File

@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import org.elasticsearch.action.ActionListener;
public class TestTransportChannel implements TransportChannel {
private final ActionListener<TransportResponse> listener;
public TestTransportChannel(ActionListener<TransportResponse> listener) {
this.listener = listener;
}
@Override
public String getProfileName() {
return "default";
}
@Override
public void sendResponse(TransportResponse response) {
listener.onResponse(response);
}
@Override
public void sendResponse(Exception exception) {
listener.onFailure(exception);
}
@Override
public String getChannelType() {
return "test";
}
}