Use TransportChannel in TransportHandshaker (#54921)
Currently the TransportHandshaker has a specialized codepath for sending a response. In other work, we are going to start having handshakes contribute to circuit breaking (while not being breakable). This commit moves in that direction by allowing the handshaker to responding using a standard TcpTransportChannel similar to other requests.
This commit is contained in:
parent
ce7ae4a7d1
commit
c7053ef824
|
@ -157,7 +157,10 @@ public class InboundHandler {
|
||||||
try {
|
try {
|
||||||
messageListener.onRequestReceived(requestId, action);
|
messageListener.onRequestReceived(requestId, action);
|
||||||
if (header.isHandshake()) {
|
if (header.isHandshake()) {
|
||||||
handshaker.handleHandshake(version, features, channel, requestId, stream);
|
// Handshakes are not currently circuit broken
|
||||||
|
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||||
|
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
|
||||||
|
handshaker.handleHandshake(transportChannel, requestId, stream);
|
||||||
} else {
|
} else {
|
||||||
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
|
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
|
||||||
if (reg == null) {
|
if (reg == null) {
|
||||||
|
@ -170,7 +173,7 @@ public class InboundHandler {
|
||||||
breaker.addWithoutBreaking(messageLengthBytes);
|
breaker.addWithoutBreaking(messageLengthBytes);
|
||||||
}
|
}
|
||||||
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||||
circuitBreakerService, messageLengthBytes, header.isCompressed());
|
circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
|
||||||
final T request = reg.newRequest(stream);
|
final T request = reg.newRequest(stream);
|
||||||
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
|
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
|
||||||
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
|
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
|
||||||
|
@ -186,7 +189,7 @@ public class InboundHandler {
|
||||||
// the circuit breaker tripped
|
// the circuit breaker tripped
|
||||||
if (transportChannel == null) {
|
if (transportChannel == null) {
|
||||||
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||||
circuitBreakerService, 0, header.isCompressed());
|
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
transportChannel.sendResponse(e);
|
transportChannel.sendResponse(e);
|
||||||
|
|
|
@ -159,9 +159,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
|
||||||
this.handshaker = new TransportHandshaker(version, threadPool,
|
this.handshaker = new TransportHandshaker(version, threadPool,
|
||||||
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
|
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
|
||||||
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
|
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
|
||||||
TransportRequestOptions.EMPTY, v, false, true),
|
TransportRequestOptions.EMPTY, v, false, true));
|
||||||
(v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
|
|
||||||
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
|
|
||||||
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
|
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
|
||||||
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
|
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
|
||||||
keepAlive);
|
keepAlive);
|
||||||
|
|
|
@ -39,9 +39,11 @@ public final class TcpTransportChannel implements TransportChannel {
|
||||||
private final CircuitBreakerService breakerService;
|
private final CircuitBreakerService breakerService;
|
||||||
private final long reservedBytes;
|
private final long reservedBytes;
|
||||||
private final boolean compressResponse;
|
private final boolean compressResponse;
|
||||||
|
private final boolean isHandshake;
|
||||||
|
|
||||||
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
|
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
|
||||||
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
|
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse,
|
||||||
|
boolean isHandshake) {
|
||||||
this.version = version;
|
this.version = version;
|
||||||
this.features = features;
|
this.features = features;
|
||||||
this.channel = channel;
|
this.channel = channel;
|
||||||
|
@ -51,6 +53,7 @@ public final class TcpTransportChannel implements TransportChannel {
|
||||||
this.breakerService = breakerService;
|
this.breakerService = breakerService;
|
||||||
this.reservedBytes = reservedBytes;
|
this.reservedBytes = reservedBytes;
|
||||||
this.compressResponse = compressResponse;
|
this.compressResponse = compressResponse;
|
||||||
|
this.isHandshake = isHandshake;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -61,7 +64,7 @@ public final class TcpTransportChannel implements TransportChannel {
|
||||||
@Override
|
@Override
|
||||||
public void sendResponse(TransportResponse response) throws IOException {
|
public void sendResponse(TransportResponse response) throws IOException {
|
||||||
try {
|
try {
|
||||||
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false);
|
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, isHandshake);
|
||||||
} finally {
|
} finally {
|
||||||
release(false);
|
release(false);
|
||||||
}
|
}
|
||||||
|
@ -102,6 +105,5 @@ public final class TcpTransportChannel implements TransportChannel {
|
||||||
public TcpChannel getChannel() {
|
public TcpChannel getChannel() {
|
||||||
return channel;
|
return channel;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,6 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
|
||||||
import java.io.EOFException;
|
import java.io.EOFException;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.ConcurrentMap;
|
import java.util.concurrent.ConcurrentMap;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
@ -49,14 +48,11 @@ final class TransportHandshaker {
|
||||||
private final Version version;
|
private final Version version;
|
||||||
private final ThreadPool threadPool;
|
private final ThreadPool threadPool;
|
||||||
private final HandshakeRequestSender handshakeRequestSender;
|
private final HandshakeRequestSender handshakeRequestSender;
|
||||||
private final HandshakeResponseSender handshakeResponseSender;
|
|
||||||
|
|
||||||
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
|
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
|
||||||
HandshakeResponseSender handshakeResponseSender) {
|
|
||||||
this.version = version;
|
this.version = version;
|
||||||
this.threadPool = threadPool;
|
this.threadPool = threadPool;
|
||||||
this.handshakeRequestSender = handshakeRequestSender;
|
this.handshakeRequestSender = handshakeRequestSender;
|
||||||
this.handshakeResponseSender = handshakeResponseSender;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
|
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
|
||||||
|
@ -88,7 +84,7 @@ final class TransportHandshaker {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
|
void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
|
||||||
// Must read the handshake request to exhaust the stream
|
// Must read the handshake request to exhaust the stream
|
||||||
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
|
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
|
||||||
final int nextByte = stream.read();
|
final int nextByte = stream.read();
|
||||||
|
@ -96,8 +92,7 @@ final class TransportHandshaker {
|
||||||
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
|
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
|
||||||
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
|
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
|
||||||
}
|
}
|
||||||
HandshakeResponse response = new HandshakeResponse(this.version);
|
channel.sendResponse(new HandshakeResponse(this.version));
|
||||||
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
|
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
|
||||||
|
@ -228,11 +223,4 @@ final class TransportHandshaker {
|
||||||
|
|
||||||
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
|
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
|
||||||
}
|
}
|
||||||
|
|
||||||
@FunctionalInterface
|
|
||||||
interface HandshakeResponseSender {
|
|
||||||
|
|
||||||
void sendResponse(Version version, Set<String> features, TcpChannel channel, TransportResponse response, long requestId)
|
|
||||||
throws IOException;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.test.transport.CapturingTransport;
|
import org.elasticsearch.test.transport.CapturingTransport;
|
||||||
import org.elasticsearch.threadpool.TestThreadPool;
|
import org.elasticsearch.threadpool.TestThreadPool;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TransportChannel;
|
import org.elasticsearch.transport.TestTransportChannel;
|
||||||
import org.elasticsearch.transport.TransportResponse;
|
import org.elasticsearch.transport.TransportResponse;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -366,14 +366,15 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
|
||||||
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
|
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
|
||||||
action.new BroadcastByNodeTransportRequestHandler();
|
action.new BroadcastByNodeTransportRequestHandler();
|
||||||
|
|
||||||
TestTransportChannel channel = new TestTransportChannel();
|
final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
|
||||||
|
TestTransportChannel channel = new TestTransportChannel(future);
|
||||||
|
|
||||||
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
|
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
|
||||||
|
|
||||||
// check the operation was executed only on the expected shards
|
// check the operation was executed only on the expected shards
|
||||||
assertEquals(shards, action.getResults().keySet());
|
assertEquals(shards, action.getResults().keySet());
|
||||||
|
|
||||||
TransportResponse response = channel.getCapturedResponse();
|
TransportResponse response = future.actionGet();
|
||||||
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
|
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
|
||||||
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
|
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
|
||||||
|
|
||||||
|
@ -469,32 +470,4 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
|
||||||
assertEquals("failed shards", totalFailedShards, response.getFailedShards());
|
assertEquals("failed shards", totalFailedShards, response.getFailedShards());
|
||||||
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
|
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TestTransportChannel implements TransportChannel {
|
|
||||||
private TransportResponse capturedResponse;
|
|
||||||
|
|
||||||
public TransportResponse getCapturedResponse() {
|
|
||||||
return capturedResponse;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getProfileName() {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(TransportResponse response) throws IOException {
|
|
||||||
capturedResponse = response;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(Exception exception) throws IOException {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getChannelType() {
|
|
||||||
return "test";
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,6 +78,7 @@ import org.elasticsearch.test.transport.CapturingTransport;
|
||||||
import org.elasticsearch.test.transport.MockTransportService;
|
import org.elasticsearch.test.transport.MockTransportService;
|
||||||
import org.elasticsearch.threadpool.TestThreadPool;
|
import org.elasticsearch.threadpool.TestThreadPool;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.transport.TestTransportChannel;
|
||||||
import org.elasticsearch.transport.Transport;
|
import org.elasticsearch.transport.Transport;
|
||||||
import org.elasticsearch.transport.TransportChannel;
|
import org.elasticsearch.transport.TransportChannel;
|
||||||
import org.elasticsearch.transport.TransportException;
|
import org.elasticsearch.transport.TransportException;
|
||||||
|
@ -817,7 +818,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
Request request = new Request(shardId);
|
Request request = new Request(shardId);
|
||||||
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
|
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
|
||||||
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
|
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
|
||||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||||
|
|
||||||
|
|
||||||
final IndexShard shard = mockIndexShard(shardId, clusterService);
|
final IndexShard shard = mockIndexShard(shardId, clusterService);
|
||||||
|
@ -981,7 +982,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
|
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
|
||||||
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
|
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
|
||||||
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
|
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
|
||||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||||
final boolean wrongAllocationId = randomBoolean();
|
final boolean wrongAllocationId = randomBoolean();
|
||||||
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
|
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
|
||||||
Request request = new Request(shardId).timeout("1ms");
|
Request request = new Request(shardId).timeout("1ms");
|
||||||
|
@ -1018,7 +1019,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
|
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
|
||||||
setState(clusterService, state);
|
setState(clusterService, state);
|
||||||
|
|
||||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||||
Request request = new Request(shardId).timeout("1ms");
|
Request request = new Request(shardId).timeout("1ms");
|
||||||
action.handleReplicaRequest(
|
action.handleReplicaRequest(
|
||||||
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
|
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
|
||||||
|
@ -1062,7 +1063,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
return new ReplicaResult();
|
return new ReplicaResult();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||||
final Request request = new Request(shardId);
|
final Request request = new Request(shardId);
|
||||||
final long checkpoint = randomNonNegativeLong();
|
final long checkpoint = randomNonNegativeLong();
|
||||||
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
|
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
|
||||||
|
@ -1130,7 +1131,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
return new ReplicaResult();
|
return new ReplicaResult();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||||
final Request request = new Request(shardId);
|
final Request request = new Request(shardId);
|
||||||
final long checkpoint = randomNonNegativeLong();
|
final long checkpoint = randomNonNegativeLong();
|
||||||
final long maxSeqNoOfUpdates = randomNonNegativeLong();
|
final long maxSeqNoOfUpdates = randomNonNegativeLong();
|
||||||
|
@ -1371,29 +1372,8 @@ public class TransportReplicationActionTests extends ESTestCase {
|
||||||
/**
|
/**
|
||||||
* Transport channel that is needed for replica operation testing.
|
* Transport channel that is needed for replica operation testing.
|
||||||
*/
|
*/
|
||||||
public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) {
|
public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
|
||||||
return new TransportChannel() {
|
return new TestTransportChannel(listener);
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getProfileName() {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(TransportResponse response) {
|
|
||||||
listener.onResponse(((TestResponse) response));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(Exception exception) {
|
|
||||||
listener.onFailure(exception);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getChannelType() {
|
|
||||||
return "replica_test";
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.elasticsearch.cluster.coordination;
|
||||||
|
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.cluster.ClusterName;
|
import org.elasticsearch.cluster.ClusterName;
|
||||||
import org.elasticsearch.cluster.ClusterState;
|
import org.elasticsearch.cluster.ClusterState;
|
||||||
import org.elasticsearch.cluster.ESAllocationTestCase;
|
import org.elasticsearch.cluster.ESAllocationTestCase;
|
||||||
|
@ -44,8 +45,8 @@ import org.elasticsearch.test.transport.CapturingTransport;
|
||||||
import org.elasticsearch.threadpool.TestThreadPool;
|
import org.elasticsearch.threadpool.TestThreadPool;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.RequestHandlerRegistry;
|
import org.elasticsearch.transport.RequestHandlerRegistry;
|
||||||
|
import org.elasticsearch.transport.TestTransportChannel;
|
||||||
import org.elasticsearch.transport.Transport;
|
import org.elasticsearch.transport.Transport;
|
||||||
import org.elasticsearch.transport.TransportChannel;
|
|
||||||
import org.elasticsearch.transport.TransportRequest;
|
import org.elasticsearch.transport.TransportRequest;
|
||||||
import org.elasticsearch.transport.TransportResponse;
|
import org.elasticsearch.transport.TransportResponse;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
|
@ -229,29 +230,22 @@ public class NodeJoinTests extends ESTestCase {
|
||||||
try {
|
try {
|
||||||
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
|
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
|
||||||
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
|
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
|
||||||
joinHandler.processMessageReceived(joinRequest, new TransportChannel() {
|
final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() {
|
||||||
@Override
|
|
||||||
public String getProfileName() {
|
|
||||||
return "dummy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getChannelType() {
|
public void onResponse(TransportResponse transportResponse) {
|
||||||
return "dummy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(TransportResponse response) {
|
|
||||||
logger.debug("{} completed", future);
|
logger.debug("{} completed", future);
|
||||||
future.markAsDone();
|
future.markAsDone();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void sendResponse(Exception e) {
|
public void onFailure(Exception e) {
|
||||||
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
||||||
future.markAsFailed(e);
|
future.markAsFailed(e);
|
||||||
}
|
}
|
||||||
});
|
};
|
||||||
|
|
||||||
|
joinHandler.processMessageReceived(joinRequest, new TestTransportChannel(listener));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
||||||
future.markAsFailed(e);
|
future.markAsFailed(e);
|
||||||
|
@ -402,27 +396,17 @@ public class NodeJoinTests extends ESTestCase {
|
||||||
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
|
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
|
||||||
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
|
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
|
||||||
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
|
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
|
||||||
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TransportChannel() {
|
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(
|
||||||
@Override
|
new ActionListener<TransportResponse>() {
|
||||||
public String getProfileName() {
|
@Override
|
||||||
return "dummy";
|
public void onResponse(TransportResponse transportResponse) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getChannelType() {
|
public void onFailure(Exception e) {
|
||||||
return "dummy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(TransportResponse response) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(Exception exception) {
|
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
});
|
}));
|
||||||
deterministicTaskQueue.runAllRunnableTasks();
|
deterministicTaskQueue.runAllRunnableTasks();
|
||||||
assertFalse(isLocalNodeElectedMaster());
|
assertFalse(isLocalNodeElectedMaster());
|
||||||
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
|
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
|
||||||
|
@ -432,27 +416,19 @@ public class NodeJoinTests extends ESTestCase {
|
||||||
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
|
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
|
||||||
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
|
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
|
||||||
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
|
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
|
||||||
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), new TransportChannel() {
|
final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() {
|
||||||
@Override
|
@Override
|
||||||
public String getProfileName() {
|
public void onResponse(TransportResponse transportResponse) {
|
||||||
return "dummy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getChannelType() {
|
|
||||||
return "dummy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendResponse(TransportResponse response) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void sendResponse(Exception exception) {
|
public void onFailure(Exception e) {
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), channel);
|
||||||
|
// Will throw exception if failed
|
||||||
deterministicTaskQueue.runAllRunnableTasks();
|
deterministicTaskQueue.runAllRunnableTasks();
|
||||||
assertFalse(isLocalNodeElectedMaster());
|
assertFalse(isLocalNodeElectedMaster());
|
||||||
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));
|
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class InboundHandlerTests extends ESTestCase {
|
||||||
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
|
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
|
||||||
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
|
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
|
||||||
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
|
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
|
||||||
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}, (v, f, c, r, r_id) -> {});
|
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
|
||||||
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
|
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
|
||||||
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
|
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
|
||||||
BigArrays.NON_RECYCLING_INSTANCE);
|
BigArrays.NON_RECYCLING_INSTANCE);
|
||||||
|
|
|
@ -27,14 +27,12 @@ import org.elasticsearch.common.unit.TimeValue;
|
||||||
import org.elasticsearch.tasks.TaskId;
|
import org.elasticsearch.tasks.TaskId;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.threadpool.TestThreadPool;
|
import org.elasticsearch.threadpool.TestThreadPool;
|
||||||
import org.mockito.ArgumentCaptor;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.mockito.Matchers.eq;
|
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
@ -46,7 +44,6 @@ public class TransportHandshakerTests extends ESTestCase {
|
||||||
private TcpChannel channel;
|
private TcpChannel channel;
|
||||||
private TestThreadPool threadPool;
|
private TestThreadPool threadPool;
|
||||||
private TransportHandshaker.HandshakeRequestSender requestSender;
|
private TransportHandshaker.HandshakeRequestSender requestSender;
|
||||||
private TransportHandshaker.HandshakeResponseSender responseSender;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
|
@ -54,11 +51,10 @@ public class TransportHandshakerTests extends ESTestCase {
|
||||||
String nodeId = "node-id";
|
String nodeId = "node-id";
|
||||||
channel = mock(TcpChannel.class);
|
channel = mock(TcpChannel.class);
|
||||||
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
|
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
|
||||||
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
|
|
||||||
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
|
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
|
||||||
Collections.emptySet(), Version.CURRENT);
|
Collections.emptySet(), Version.CURRENT);
|
||||||
threadPool = new TestThreadPool("thread-poll");
|
threadPool = new TestThreadPool("thread-poll");
|
||||||
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
|
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -76,20 +72,16 @@ public class TransportHandshakerTests extends ESTestCase {
|
||||||
|
|
||||||
assertFalse(versionFuture.isDone());
|
assertFalse(versionFuture.isDone());
|
||||||
|
|
||||||
TcpChannel mockChannel = mock(TcpChannel.class);
|
|
||||||
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
||||||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
|
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
|
||||||
handshakeRequest.writeTo(bytesStreamOutput);
|
handshakeRequest.writeTo(bytesStreamOutput);
|
||||||
StreamInput input = bytesStreamOutput.bytes().streamInput();
|
StreamInput input = bytesStreamOutput.bytes().streamInput();
|
||||||
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input);
|
final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
|
||||||
|
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
|
||||||
|
handshaker.handleHandshake(channel, reqId, input);
|
||||||
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
|
|
||||||
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
|
|
||||||
eq(reqId));
|
|
||||||
|
|
||||||
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
|
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
|
||||||
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());
|
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseFuture.actionGet());
|
||||||
|
|
||||||
assertTrue(versionFuture.isDone());
|
assertTrue(versionFuture.isDone());
|
||||||
assertEquals(Version.CURRENT, versionFuture.actionGet());
|
assertEquals(Version.CURRENT, versionFuture.actionGet());
|
||||||
|
@ -101,7 +93,6 @@ public class TransportHandshakerTests extends ESTestCase {
|
||||||
|
|
||||||
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
|
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
|
||||||
|
|
||||||
TcpChannel mockChannel = mock(TcpChannel.class);
|
|
||||||
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
||||||
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
|
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
|
||||||
handshakeRequest.writeTo(currentHandshakeBytes);
|
handshakeRequest.writeTo(currentHandshakeBytes);
|
||||||
|
@ -121,15 +112,12 @@ public class TransportHandshakerTests extends ESTestCase {
|
||||||
// Otherwise, we need to update the test.
|
// Otherwise, we need to update the test.
|
||||||
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
|
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
|
||||||
assertEquals(1031, futureHandshakeStream.available());
|
assertEquals(1031, futureHandshakeStream.available());
|
||||||
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream);
|
final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
|
||||||
|
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
|
||||||
|
handshaker.handleHandshake(channel, reqId, futureHandshakeStream);
|
||||||
assertEquals(0, futureHandshakeStream.available());
|
assertEquals(0, futureHandshakeStream.available());
|
||||||
|
|
||||||
|
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseFuture.actionGet();
|
||||||
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
|
|
||||||
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
|
|
||||||
eq(reqId));
|
|
||||||
|
|
||||||
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
|
|
||||||
|
|
||||||
assertEquals(Version.CURRENT, response.getResponseVersion());
|
assertEquals(Version.CURRENT, response.getResponseVersion());
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.transport;
|
||||||
|
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
|
||||||
|
public class TestTransportChannel implements TransportChannel {
|
||||||
|
|
||||||
|
private final ActionListener<TransportResponse> listener;
|
||||||
|
|
||||||
|
public TestTransportChannel(ActionListener<TransportResponse> listener) {
|
||||||
|
this.listener = listener;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getProfileName() {
|
||||||
|
return "default";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void sendResponse(TransportResponse response) {
|
||||||
|
listener.onResponse(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void sendResponse(Exception exception) {
|
||||||
|
listener.onFailure(exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getChannelType() {
|
||||||
|
return "test";
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue