Use TransportChannel in TransportHandshaker (#54921)
Currently the TransportHandshaker has a specialized codepath for sending a response. In other work, we are going to start having handshakes contribute to circuit breaking (while not being breakable). This commit moves in that direction by allowing the handshaker to responding using a standard TcpTransportChannel similar to other requests.
This commit is contained in:
parent
ce7ae4a7d1
commit
c7053ef824
|
@ -157,7 +157,10 @@ public class InboundHandler {
|
|||
try {
|
||||
messageListener.onRequestReceived(requestId, action);
|
||||
if (header.isHandshake()) {
|
||||
handshaker.handleHandshake(version, features, channel, requestId, stream);
|
||||
// Handshakes are not currently circuit broken
|
||||
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
|
||||
handshaker.handleHandshake(transportChannel, requestId, stream);
|
||||
} else {
|
||||
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
|
||||
if (reg == null) {
|
||||
|
@ -170,7 +173,7 @@ public class InboundHandler {
|
|||
breaker.addWithoutBreaking(messageLengthBytes);
|
||||
}
|
||||
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||
circuitBreakerService, messageLengthBytes, header.isCompressed());
|
||||
circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
|
||||
final T request = reg.newRequest(stream);
|
||||
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
|
||||
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
|
||||
|
@ -186,7 +189,7 @@ public class InboundHandler {
|
|||
// the circuit breaker tripped
|
||||
if (transportChannel == null) {
|
||||
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
|
||||
circuitBreakerService, 0, header.isCompressed());
|
||||
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
|
||||
}
|
||||
try {
|
||||
transportChannel.sendResponse(e);
|
||||
|
|
|
@ -159,9 +159,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
|
|||
this.handshaker = new TransportHandshaker(version, threadPool,
|
||||
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
|
||||
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
|
||||
TransportRequestOptions.EMPTY, v, false, true),
|
||||
(v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
|
||||
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
|
||||
TransportRequestOptions.EMPTY, v, false, true));
|
||||
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
|
||||
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
|
||||
keepAlive);
|
||||
|
|
|
@ -39,9 +39,11 @@ public final class TcpTransportChannel implements TransportChannel {
|
|||
private final CircuitBreakerService breakerService;
|
||||
private final long reservedBytes;
|
||||
private final boolean compressResponse;
|
||||
private final boolean isHandshake;
|
||||
|
||||
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
|
||||
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
|
||||
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse,
|
||||
boolean isHandshake) {
|
||||
this.version = version;
|
||||
this.features = features;
|
||||
this.channel = channel;
|
||||
|
@ -51,6 +53,7 @@ public final class TcpTransportChannel implements TransportChannel {
|
|||
this.breakerService = breakerService;
|
||||
this.reservedBytes = reservedBytes;
|
||||
this.compressResponse = compressResponse;
|
||||
this.isHandshake = isHandshake;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -61,7 +64,7 @@ public final class TcpTransportChannel implements TransportChannel {
|
|||
@Override
|
||||
public void sendResponse(TransportResponse response) throws IOException {
|
||||
try {
|
||||
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false);
|
||||
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, isHandshake);
|
||||
} finally {
|
||||
release(false);
|
||||
}
|
||||
|
@ -102,6 +105,5 @@ public final class TcpTransportChannel implements TransportChannel {
|
|||
public TcpChannel getChannel() {
|
||||
return channel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
@ -49,14 +48,11 @@ final class TransportHandshaker {
|
|||
private final Version version;
|
||||
private final ThreadPool threadPool;
|
||||
private final HandshakeRequestSender handshakeRequestSender;
|
||||
private final HandshakeResponseSender handshakeResponseSender;
|
||||
|
||||
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
|
||||
HandshakeResponseSender handshakeResponseSender) {
|
||||
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
|
||||
this.version = version;
|
||||
this.threadPool = threadPool;
|
||||
this.handshakeRequestSender = handshakeRequestSender;
|
||||
this.handshakeResponseSender = handshakeResponseSender;
|
||||
}
|
||||
|
||||
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
|
||||
|
@ -88,7 +84,7 @@ final class TransportHandshaker {
|
|||
}
|
||||
}
|
||||
|
||||
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
|
||||
void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
|
||||
// Must read the handshake request to exhaust the stream
|
||||
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
|
||||
final int nextByte = stream.read();
|
||||
|
@ -96,8 +92,7 @@ final class TransportHandshaker {
|
|||
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
|
||||
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
|
||||
}
|
||||
HandshakeResponse response = new HandshakeResponse(this.version);
|
||||
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
|
||||
channel.sendResponse(new HandshakeResponse(this.version));
|
||||
}
|
||||
|
||||
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
|
||||
|
@ -228,11 +223,4 @@ final class TransportHandshaker {
|
|||
|
||||
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
interface HandshakeResponseSender {
|
||||
|
||||
void sendResponse(Version version, Set<String> features, TcpChannel channel, TransportResponse response, long requestId)
|
||||
throws IOException;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.test.transport.CapturingTransport;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportChannel;
|
||||
import org.elasticsearch.transport.TestTransportChannel;
|
||||
import org.elasticsearch.transport.TransportResponse;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.junit.After;
|
||||
|
@ -366,14 +366,15 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
|
|||
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
|
||||
action.new BroadcastByNodeTransportRequestHandler();
|
||||
|
||||
TestTransportChannel channel = new TestTransportChannel();
|
||||
final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
|
||||
TestTransportChannel channel = new TestTransportChannel(future);
|
||||
|
||||
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
|
||||
|
||||
// check the operation was executed only on the expected shards
|
||||
assertEquals(shards, action.getResults().keySet());
|
||||
|
||||
TransportResponse response = channel.getCapturedResponse();
|
||||
TransportResponse response = future.actionGet();
|
||||
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
|
||||
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
|
||||
|
||||
|
@ -469,32 +470,4 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
|
|||
assertEquals("failed shards", totalFailedShards, response.getFailedShards());
|
||||
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
|
||||
}
|
||||
|
||||
public class TestTransportChannel implements TransportChannel {
|
||||
private TransportResponse capturedResponse;
|
||||
|
||||
public TransportResponse getCapturedResponse() {
|
||||
return capturedResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) throws IOException {
|
||||
capturedResponse = response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception exception) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "test";
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,6 +78,7 @@ import org.elasticsearch.test.transport.CapturingTransport;
|
|||
import org.elasticsearch.test.transport.MockTransportService;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TestTransportChannel;
|
||||
import org.elasticsearch.transport.Transport;
|
||||
import org.elasticsearch.transport.TransportChannel;
|
||||
import org.elasticsearch.transport.TransportException;
|
||||
|
@ -817,7 +818,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
Request request = new Request(shardId);
|
||||
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
|
||||
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
|
||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
||||
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||
|
||||
|
||||
final IndexShard shard = mockIndexShard(shardId, clusterService);
|
||||
|
@ -981,7 +982,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
|
||||
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
|
||||
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
|
||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
||||
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||
final boolean wrongAllocationId = randomBoolean();
|
||||
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
|
||||
Request request = new Request(shardId).timeout("1ms");
|
||||
|
@ -1018,7 +1019,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
|
||||
setState(clusterService, state);
|
||||
|
||||
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
||||
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||
Request request = new Request(shardId).timeout("1ms");
|
||||
action.handleReplicaRequest(
|
||||
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
|
||||
|
@ -1062,7 +1063,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
return new ReplicaResult();
|
||||
}
|
||||
};
|
||||
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
||||
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||
final Request request = new Request(shardId);
|
||||
final long checkpoint = randomNonNegativeLong();
|
||||
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
|
||||
|
@ -1130,7 +1131,7 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
return new ReplicaResult();
|
||||
}
|
||||
};
|
||||
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
|
||||
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
|
||||
final Request request = new Request(shardId);
|
||||
final long checkpoint = randomNonNegativeLong();
|
||||
final long maxSeqNoOfUpdates = randomNonNegativeLong();
|
||||
|
@ -1371,29 +1372,8 @@ public class TransportReplicationActionTests extends ESTestCase {
|
|||
/**
|
||||
* Transport channel that is needed for replica operation testing.
|
||||
*/
|
||||
public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) {
|
||||
return new TransportChannel() {
|
||||
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) {
|
||||
listener.onResponse(((TestResponse) response));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception exception) {
|
||||
listener.onFailure(exception);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "replica_test";
|
||||
}
|
||||
};
|
||||
public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
|
||||
return new TestTransportChannel(listener);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.elasticsearch.cluster.coordination;
|
|||
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.cluster.ClusterName;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.ESAllocationTestCase;
|
||||
|
@ -44,8 +45,8 @@ import org.elasticsearch.test.transport.CapturingTransport;
|
|||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.RequestHandlerRegistry;
|
||||
import org.elasticsearch.transport.TestTransportChannel;
|
||||
import org.elasticsearch.transport.Transport;
|
||||
import org.elasticsearch.transport.TransportChannel;
|
||||
import org.elasticsearch.transport.TransportRequest;
|
||||
import org.elasticsearch.transport.TransportResponse;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
|
@ -229,29 +230,22 @@ public class NodeJoinTests extends ESTestCase {
|
|||
try {
|
||||
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
|
||||
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
|
||||
joinHandler.processMessageReceived(joinRequest, new TransportChannel() {
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "dummy";
|
||||
}
|
||||
final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() {
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "dummy";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) {
|
||||
public void onResponse(TransportResponse transportResponse) {
|
||||
logger.debug("{} completed", future);
|
||||
future.markAsDone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception e) {
|
||||
public void onFailure(Exception e) {
|
||||
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
||||
future.markAsFailed(e);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
joinHandler.processMessageReceived(joinRequest, new TestTransportChannel(listener));
|
||||
} catch (Exception e) {
|
||||
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
|
||||
future.markAsFailed(e);
|
||||
|
@ -402,27 +396,17 @@ public class NodeJoinTests extends ESTestCase {
|
|||
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
|
||||
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
|
||||
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
|
||||
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TransportChannel() {
|
||||
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(
|
||||
new ActionListener<TransportResponse>() {
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "dummy";
|
||||
public void onResponse(TransportResponse transportResponse) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "dummy";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception exception) {
|
||||
public void onFailure(Exception e) {
|
||||
fail();
|
||||
}
|
||||
});
|
||||
}));
|
||||
deterministicTaskQueue.runAllRunnableTasks();
|
||||
assertFalse(isLocalNodeElectedMaster());
|
||||
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
|
||||
|
@ -432,27 +416,19 @@ public class NodeJoinTests extends ESTestCase {
|
|||
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
|
||||
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
|
||||
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
|
||||
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), new TransportChannel() {
|
||||
final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() {
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "dummy";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "dummy";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) {
|
||||
public void onResponse(TransportResponse transportResponse) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception exception) {
|
||||
public void onFailure(Exception e) {
|
||||
fail();
|
||||
}
|
||||
});
|
||||
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), channel);
|
||||
// Will throw exception if failed
|
||||
deterministicTaskQueue.runAllRunnableTasks();
|
||||
assertFalse(isLocalNodeElectedMaster());
|
||||
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));
|
||||
|
|
|
@ -58,7 +58,7 @@ public class InboundHandlerTests extends ESTestCase {
|
|||
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
|
||||
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
|
||||
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
|
||||
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}, (v, f, c, r, r_id) -> {});
|
||||
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
|
||||
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
|
||||
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
|
||||
BigArrays.NON_RECYCLING_INSTANCE);
|
||||
|
|
|
@ -27,14 +27,12 @@ import org.elasticsearch.common.unit.TimeValue;
|
|||
import org.elasticsearch.tasks.TaskId;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -46,7 +44,6 @@ public class TransportHandshakerTests extends ESTestCase {
|
|||
private TcpChannel channel;
|
||||
private TestThreadPool threadPool;
|
||||
private TransportHandshaker.HandshakeRequestSender requestSender;
|
||||
private TransportHandshaker.HandshakeResponseSender responseSender;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
|
@ -54,11 +51,10 @@ public class TransportHandshakerTests extends ESTestCase {
|
|||
String nodeId = "node-id";
|
||||
channel = mock(TcpChannel.class);
|
||||
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
|
||||
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
|
||||
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
|
||||
Collections.emptySet(), Version.CURRENT);
|
||||
threadPool = new TestThreadPool("thread-poll");
|
||||
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
|
||||
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -76,20 +72,16 @@ public class TransportHandshakerTests extends ESTestCase {
|
|||
|
||||
assertFalse(versionFuture.isDone());
|
||||
|
||||
TcpChannel mockChannel = mock(TcpChannel.class);
|
||||
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
||||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
|
||||
handshakeRequest.writeTo(bytesStreamOutput);
|
||||
StreamInput input = bytesStreamOutput.bytes().streamInput();
|
||||
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input);
|
||||
|
||||
|
||||
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
|
||||
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
|
||||
eq(reqId));
|
||||
final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
|
||||
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
|
||||
handshaker.handleHandshake(channel, reqId, input);
|
||||
|
||||
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
|
||||
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());
|
||||
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseFuture.actionGet());
|
||||
|
||||
assertTrue(versionFuture.isDone());
|
||||
assertEquals(Version.CURRENT, versionFuture.actionGet());
|
||||
|
@ -101,7 +93,6 @@ public class TransportHandshakerTests extends ESTestCase {
|
|||
|
||||
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
|
||||
|
||||
TcpChannel mockChannel = mock(TcpChannel.class);
|
||||
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
|
||||
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
|
||||
handshakeRequest.writeTo(currentHandshakeBytes);
|
||||
|
@ -121,15 +112,12 @@ public class TransportHandshakerTests extends ESTestCase {
|
|||
// Otherwise, we need to update the test.
|
||||
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
|
||||
assertEquals(1031, futureHandshakeStream.available());
|
||||
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream);
|
||||
final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
|
||||
final TestTransportChannel channel = new TestTransportChannel(responseFuture);
|
||||
handshaker.handleHandshake(channel, reqId, futureHandshakeStream);
|
||||
assertEquals(0, futureHandshakeStream.available());
|
||||
|
||||
|
||||
ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
|
||||
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
|
||||
eq(reqId));
|
||||
|
||||
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
|
||||
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseFuture.actionGet();
|
||||
|
||||
assertEquals(Version.CURRENT, response.getResponseVersion());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.transport;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
|
||||
public class TestTransportChannel implements TransportChannel {
|
||||
|
||||
private final ActionListener<TransportResponse> listener;
|
||||
|
||||
public TestTransportChannel(ActionListener<TransportResponse> listener) {
|
||||
this.listener = listener;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getProfileName() {
|
||||
return "default";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(TransportResponse response) {
|
||||
listener.onResponse(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendResponse(Exception exception) {
|
||||
listener.onFailure(exception);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getChannelType() {
|
||||
return "test";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue