Introduce mechanism to stub request handling (#55832)

Currently there is a clear mechanism to stub sending a request through
the transport. However, this is limited to testing exceptions on the
sender side. This commit reworks our transport related testing
infrastructure to allow stubbing request handling on the receiving side.
This commit is contained in:
Tim Brooks 2020-04-27 16:57:15 -06:00 committed by GitHub
parent 2ff858b290
commit 80662f31a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 305 additions and 221 deletions

View File

@ -32,8 +32,8 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.Transports;
import java.nio.channels.ClosedChannelException;
@ -56,9 +56,9 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
Netty4MessageChannelHandler(PageCacheRecycler recycler, Netty4Transport transport) {
this.transport = transport;
final ThreadPool threadPool = transport.getThreadPool();
final InboundHandler inboundHandler = transport.getInboundHandler();
final Transport.RequestHandlers requestHandlers = transport.getRequestHandlers();
this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis,
transport.getInflightBreaker(), inboundHandler::getRequestHandler, transport::inboundMessage);
transport.getInflightBreaker(), requestHandlers::getHandler, transport::inboundMessage);
}
@Override

View File

@ -31,9 +31,9 @@ import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import java.io.IOException;
import java.util.function.Supplier;
@ -47,9 +47,9 @@ public class TcpReadWriteHandler extends BytesWriteHandler {
this.channel = channel;
final ThreadPool threadPool = transport.getThreadPool();
final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker();
final InboundHandler inboundHandler = transport.getInboundHandler();
final Transport.RequestHandlers requestHandlers = transport.getRequestHandlers();
this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis,
breaker, inboundHandler::getRequestHandler, transport::inboundMessage);
breaker, requestHandlers::getHandler, transport::inboundMessage);
}
@Override

View File

@ -23,7 +23,6 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
@ -34,8 +33,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Map;
public class InboundHandler {
@ -46,34 +43,21 @@ public class InboundHandler {
private final NamedWriteableRegistry namedWriteableRegistry;
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
private final Transport.ResponseHandlers responseHandlers;
private final Transport.RequestHandlers requestHandlers;
private final Transport.ResponseHandlers responseHandlers = new Transport.ResponseHandlers();
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker, TransportKeepAlive keepAlive) {
TransportHandshaker handshaker, TransportKeepAlive keepAlive, Transport.RequestHandlers requestHandlers,
Transport.ResponseHandlers responseHandlers) {
this.threadPool = threadPool;
this.outboundHandler = outboundHandler;
this.namedWriteableRegistry = namedWriteableRegistry;
this.handshaker = handshaker;
this.keepAlive = keepAlive;
}
synchronized <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
@SuppressWarnings("unchecked")
public final <T extends TransportRequest> RequestHandlerRegistry<T> getRequestHandler(String action) {
return (RequestHandlerRegistry<T>) requestHandlers.get(action);
}
final Transport.ResponseHandlers getResponseHandlers() {
return responseHandlers;
this.requestHandlers = requestHandlers;
this.responseHandlers = responseHandlers;
}
void setMessageListener(TransportMessageListener listener) {
@ -171,7 +155,7 @@ public class InboundHandler {
} else {
final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(stream, header.getVersion());
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
assert reg != null;
final T request = reg.newRequest(stream);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));

View File

@ -81,9 +81,18 @@ public class RequestHandlerRegistry<Request extends TransportRequest> {
return executor;
}
public TransportRequestHandler<Request> getHandler() {
return handler;
}
@Override
public String toString() {
return handler.toString();
}
public static <R extends TransportRequest> RequestHandlerRegistry<R> replaceHandler(RequestHandlerRegistry<R> registry,
TransportRequestHandler<R> handler) {
return new RequestHandlerRegistry<>(registry.action, registry.requestReader, registry.taskManager, handler,
registry.executor, registry.forceExecution, registry.canTripCircuitBreaker);
}
}

View File

@ -129,7 +129,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
private final OutboundHandler outboundHandler;
protected final InboundHandler inboundHandler;
private final InboundHandler inboundHandler;
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final RequestHandlers requestHandlers = new RequestHandlers();
public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry,
@ -163,7 +165,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive);
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive,
requestHandlers, responseHandlers);
}
public Version getVersion() {
@ -182,10 +185,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
return () -> circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
}
public InboundHandler getInboundHandler() {
return inboundHandler;
}
@Override
protected void doStart() {
}
@ -196,11 +195,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
inboundHandler.setMessageListener(listener);
}
@Override
public synchronized <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
inboundHandler.registerRequestHandler(reg);
}
public final class NodeChannels extends CloseableConnection {
private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping;
private final List<TcpChannel> channels;
@ -813,7 +807,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
}
public void executeHandshake(DiscoveryNode node, TcpChannel channel, ConnectionProfile profile, ActionListener<Version> listener) {
long requestId = inboundHandler.getResponseHandlers().newRequestId();
long requestId = responseHandlers.newRequestId();
handshaker.sendHandshake(requestId, node, channel, profile.getHandshakeTimeout(), listener);
}
@ -917,12 +911,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
@Override
public final ResponseHandlers getResponseHandlers() {
return inboundHandler.getResponseHandlers();
return responseHandlers;
}
@Override
public final RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action) {
return inboundHandler.getRequestHandler(action);
public final RequestHandlers getRequestHandlers() {
return requestHandlers;
}
private final class ChannelsConnectedListener implements ActionListener<Void> {

View File

@ -22,6 +22,7 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.component.LifecycleComponent;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
@ -32,6 +33,7 @@ import java.io.Closeable;
import java.io.IOException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
@ -42,13 +44,9 @@ public interface Transport extends LifecycleComponent {
/**
* Registers a new request handler
*/
<Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg);
/**
* Returns the registered request handler registry for the given action or <code>null</code> if it's not registered
* @param action the action to look up
*/
RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action);
default <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
getRequestHandlers().registerHandler(reg);
}
void setMessageListener(TransportMessageListener listener);
@ -87,6 +85,8 @@ public interface Transport extends LifecycleComponent {
ResponseHandlers getResponseHandlers();
RequestHandlers getRequestHandlers();
/**
* A unidirectional connection to a {@link DiscoveryNode}
*/
@ -187,7 +187,7 @@ public interface Transport extends LifecycleComponent {
* Removes and return the {@link ResponseContext} for the given request ID or returns
* <code>null</code> if no context is associated with this request ID.
*/
public ResponseContext remove(long requestId) {
public ResponseContext<? extends TransportResponse> remove(long requestId) {
return handlers.remove(requestId);
}
@ -198,7 +198,7 @@ public interface Transport extends LifecycleComponent {
*/
public long add(ResponseContext<? extends TransportResponse> holder) {
long requestId = newRequestId();
ResponseContext existing = handlers.put(requestId, holder);
ResponseContext<? extends TransportResponse> existing = handlers.put(requestId, holder);
assert existing == null : "request ID already in use: " + requestId;
return requestId;
}
@ -214,12 +214,12 @@ public interface Transport extends LifecycleComponent {
/**
* Removes and returns all {@link ResponseContext} instances that match the predicate
*/
public List<ResponseContext<? extends TransportResponse>> prune(Predicate<ResponseContext> predicate) {
public List<ResponseContext<? extends TransportResponse>> prune(Predicate<ResponseContext<? extends TransportResponse>> predicate) {
final List<ResponseContext<? extends TransportResponse>> holders = new ArrayList<>();
for (Map.Entry<Long, ResponseContext<? extends TransportResponse>> entry : handlers.entrySet()) {
ResponseContext<? extends TransportResponse> holder = entry.getValue();
if (predicate.test(holder)) {
ResponseContext remove = handlers.remove(entry.getKey());
ResponseContext<? extends TransportResponse> remove = handlers.remove(entry.getKey());
if (remove != null) {
holders.add(holder);
}
@ -244,4 +244,27 @@ public interface Transport extends LifecycleComponent {
}
}
}
final class RequestHandlers {
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
synchronized <Request extends TransportRequest> void registerHandler(RequestHandlerRegistry<Request> reg) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
// TODO: Only visible for testing. Perhaps move StubbableTransport from
// org.elasticsearch.test.transport to org.elasticsearch.transport
public synchronized <Request extends TransportRequest> void forceRegister(RequestHandlerRegistry<Request> reg) {
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
@SuppressWarnings("unchecked")
public <T extends TransportRequest> RequestHandlerRegistry<T> getHandler(String action) {
return (RequestHandlerRegistry<T>) requestHandlers.get(action);
}
}
}

View File

@ -740,7 +740,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
} catch (final Exception e) {
// usually happen either because we failed to connect to the node
// or because we failed serializing the message
final Transport.ResponseContext contextToNotify = responseHandlers.remove(requestId);
final Transport.ResponseContext<? extends TransportResponse> contextToNotify = responseHandlers.remove(requestId);
// If holderToNotify == null then handler has already been taken care of.
if (contextToNotify != null) {
if (timeoutHandler != null) {
@ -986,7 +986,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
}
public RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action) {
return transport.getRequestHandler(action);
return transport.getRequestHandlers().getHandler(action);
}
private void checkForTimeout(long requestId) {
@ -1065,7 +1065,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
long timeoutTime = threadPool.relativeTimeInMillis();
timeoutInfoHandlers.put(requestId, new TimeoutInfoHolder(node, action, sentTime, timeoutTime));
// now that we have the information visible via timeoutInfoHandlers, we try to remove the request id
final Transport.ResponseContext holder = responseHandlers.remove(requestId);
final Transport.ResponseContext<? extends TransportResponse> holder = responseHandlers.remove(requestId);
if (holder != null) {
assert holder.action().equals(action);
assert holder.connection().getNode().equals(node);

View File

@ -28,7 +28,6 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.settings.Settings;
@ -37,7 +36,6 @@ import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener;
@ -60,7 +58,7 @@ abstract class FailAndRetryMockTransport<Response extends TransportResponse> imp
private final Random random;
private final ClusterName clusterName;
private volatile Map<String, RequestHandlerRegistry> requestHandlers = Collections.emptyMap();
private final RequestHandlers requestHandlers = new RequestHandlers();
private final Object requestHandlerMutex = new Object();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener;
@ -205,26 +203,16 @@ abstract class FailAndRetryMockTransport<Response extends TransportResponse> imp
throw new UnsupportedOperationException();
}
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
synchronized (requestHandlerMutex) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
}
@Override
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@Override
public RequestHandlerRegistry getRequestHandler(String action) {
return requestHandlers.get(action);
public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
@Override
public void setMessageListener(TransportMessageListener listener) {
this.listener = listener;

View File

@ -44,7 +44,6 @@ import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener;
@ -505,8 +504,9 @@ public class NodeConnectionsServiceTests extends ESTestCase {
}
}
private final class MockTransport implements Transport {
private ResponseHandlers responseHandlers = new ResponseHandlers();
private static final class MockTransport implements Transport {
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final RequestHandlers requestHandlers = new RequestHandlers();
private volatile boolean randomConnectionExceptions = false;
private final ThreadPool threadPool;
@ -514,16 +514,6 @@ public class NodeConnectionsServiceTests extends ESTestCase {
this.threadPool = threadPool;
}
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
}
@SuppressWarnings("unchecked")
@Override
public RequestHandlerRegistry getRequestHandler(String action) {
return null;
}
@Override
public void setMessageListener(TransportMessageListener listener) {
}
@ -614,5 +604,10 @@ public class NodeConnectionsServiceTests extends ESTestCase {
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@Override
public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
}
}

View File

@ -229,8 +229,7 @@ public class NodeJoinTests extends ESTestCase {
// clone the node before submitting to simulate an incoming join, which is guaranteed to have a new
// disco node object serialized off the network
try {
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
final RequestHandlerRegistry<JoinRequest> joinHandler = transport.getRequestHandlers().getHandler(JoinHelper.JOIN_ACTION_NAME);
final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() {
@Override
@ -434,8 +433,8 @@ public class NodeJoinTests extends ESTestCase {
}
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = transport.getRequestHandlers()
.getHandler(JoinHelper.START_JOIN_ACTION_NAME);
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(
new ActionListener<TransportResponse>() {
@Override
@ -453,9 +452,8 @@ public class NodeJoinTests extends ESTestCase {
}
private void handleFollowerCheckFrom(DiscoveryNode node, long term) throws Exception {
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler = transport.getRequestHandlers()
.getHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() {
@Override
public void onResponse(TransportResponse transportResponse) {

View File

@ -43,16 +43,17 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
import org.elasticsearch.action.support.replication.ReplicationResponse;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.NodeConnectionsService;
import org.elasticsearch.cluster.action.shard.ShardStateAction;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.RecoverySource;
import org.elasticsearch.cluster.routing.RecoverySource.PeerRecoverySource;
import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.cluster.routing.allocation.command.AllocateEmptyPrimaryAllocationCommand;
import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand;
@ -136,6 +137,7 @@ import static org.elasticsearch.action.DocWriteResponse.Result.UPDATED;
import static org.elasticsearch.node.RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -144,7 +146,6 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isOneOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.empty;
@ClusterScope(scope = Scope.TEST, numDataNodes = 0)
public class IndexRecoveryIT extends ESIntegTestCase {
@ -721,6 +722,7 @@ public class IndexRecoveryIT extends ESIntegTestCase {
final Settings nodeSettings = Settings.builder()
.put(RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING.getKey(), "100ms")
.put(RecoverySettings.INDICES_RECOVERY_INTERNAL_ACTION_TIMEOUT_SETTING.getKey(), "1s")
.put(NodeConnectionsService.CLUSTER_NODE_RECONNECT_INTERVAL_SETTING.getKey(), "1s")
.build();
// start a master node
internalCluster().startNode(nodeSettings);
@ -777,12 +779,29 @@ public class IndexRecoveryIT extends ESIntegTestCase {
(MockTransportService) internalCluster().getInstance(TransportService.class, redNodeName);
TransportService redTransportService = internalCluster().getInstance(TransportService.class, redNodeName);
TransportService blueTransportService = internalCluster().getInstance(TransportService.class, blueNodeName);
final CountDownLatch requestBlocked = new CountDownLatch(1);
final CountDownLatch requestFailed = new CountDownLatch(1);
blueMockTransportService.addSendBehavior(redTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestBlocked));
redMockTransportService.addSendBehavior(blueTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestBlocked));
if (randomBoolean()) {
// Fail on the sending side
blueMockTransportService.addSendBehavior(redTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestFailed));
redMockTransportService.addSendBehavior(blueTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestFailed));
} else {
// Fail on the receiving side.
blueMockTransportService.addRequestHandlingBehavior(recoveryActionToBlock, (handler, request, channel, task) -> {
logger.info("--> preventing {} response by closing response channel", recoveryActionToBlock);
requestFailed.countDown();
redMockTransportService.disconnectFromNode(blueMockTransportService.getLocalDiscoNode());
handler.messageReceived(request, channel, task);
});
redMockTransportService.addRequestHandlingBehavior(recoveryActionToBlock, (handler, request, channel, task) -> {
logger.info("--> preventing {} response by closing response channel", recoveryActionToBlock);
requestFailed.countDown();
blueMockTransportService.disconnectFromNode(redMockTransportService.getLocalDiscoNode());
handler.messageReceived(request, channel, task);
});
}
logger.info("--> starting recovery from blue to red");
client().admin().indices().prepareUpdateSettings(indexName).setSettings(
@ -791,9 +810,9 @@ public class IndexRecoveryIT extends ESIntegTestCase {
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1)
).get();
requestBlocked.await();
requestFailed.await();
logger.info("--> stopping to block recovery");
logger.info("--> clearing rules to allow recovery to proceed");
blueMockTransportService.clearAllRules();
redMockTransportService.clearAllRules();
@ -818,12 +837,14 @@ public class IndexRecoveryIT extends ESIntegTestCase {
public void sendRequest(Transport.Connection connection, long requestId, String action, TransportRequest request,
TransportRequestOptions options) throws IOException {
if (recoveryActionToBlock.equals(action) || requestBlocked.getCount() == 0) {
logger.info("--> preventing {} request", action);
requestBlocked.countDown();
if (dropRequests) {
logger.info("--> preventing {} request by dropping request", action);
return;
} else {
logger.info("--> preventing {} request by throwing ConnectTransportException", action);
throw new ConnectTransportException(connection.getNode(), "DISCONNECT: prevented " + action + " request");
}
throw new ConnectTransportException(connection.getNode(), "DISCONNECT: prevented " + action + " request");
}
connection.sendRequest(requestId, action, request, options);
}

View File

@ -48,6 +48,8 @@ public class InboundHandlerTests extends ESTestCase {
private final Version version = Version.CURRENT;
private TaskManager taskManager;
private Transport.ResponseHandlers responseHandlers;
private Transport.RequestHandlers requestHandlers;
private InboundHandler handler;
private FakeTcpChannel channel;
@ -61,7 +63,10 @@ public class InboundHandlerTests extends ESTestCase {
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
BigArrays.NON_RECYCLING_INSTANCE);
handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive);
requestHandlers = new Transport.RequestHandlers();
responseHandlers = new Transport.ResponseHandlers();
handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive, requestHandlers,
responseHandlers);
}
@After
@ -74,7 +79,7 @@ public class InboundHandlerTests extends ESTestCase {
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>("test-request", TestRequest::new, taskManager,
(request, channel, task) -> channelCaptor.set(channel), ThreadPool.Names.SAME, false, true);
handler.registerRequestHandler(registry);
requestHandlers.registerHandler(registry);
handler.inboundMessage(channel, new InboundMessage(null, true));
if (channel.isServerChannel()) {
@ -93,7 +98,7 @@ public class InboundHandlerTests extends ESTestCase {
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
long requestId = handler.getResponseHandlers().add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(TestResponse response) {
responseCaptor.set(response);
@ -119,7 +124,7 @@ public class InboundHandlerTests extends ESTestCase {
channelCaptor.set(channel);
requestCaptor.set(request);
}, ThreadPool.Names.SAME, false, true);
handler.registerRequestHandler(registry);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(threadPool.getThreadContext(), new String[0],
new TestRequest(requestValue), version, action, requestId, false, false);

View File

@ -168,7 +168,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
protected void onConnectedDuringSend(long requestId, String action, TransportRequest request,
DisruptableMockTransport destinationTransport) {
final RequestHandlerRegistry<TransportRequest> requestHandler =
destinationTransport.getRequestHandler(action);
destinationTransport.getRequestHandlers().getHandler(action);
final DiscoveryNode destination = destinationTransport.getLocalNode();

View File

@ -0,0 +1,120 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.test.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessageListener;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportStats;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* A transport that does nothing. Normally wrapped by {@link StubbableTransport}.
*/
public class FakeTransport extends AbstractLifecycleComponent implements Transport {
private final RequestHandlers requestHandlers = new RequestHandlers();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener;
@Override
public void setMessageListener(TransportMessageListener listener) {
if (this.listener != null) {
throw new IllegalStateException("listener already set");
}
this.listener = listener;
}
@Override
public BoundTransportAddress boundAddress() {
return null;
}
@Override
public Map<String, BoundTransportAddress> profileBoundAddresses() {
return null;
}
@Override
public TransportAddress[] addressesFromString(String address) {
return new TransportAddress[0];
}
@Override
public List<String> getDefaultSeedAddresses() {
return Collections.emptyList();
}
@Override
public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
listener.onResponse(new CloseableConnection() {
@Override
public DiscoveryNode getNode() {
return node;
}
@Override
public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) {
}
});
}
@Override
public TransportStats getStats() {
return null;
}
@Override
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@Override
public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
@Override
protected void doClose() {
}
}

View File

@ -19,31 +19,22 @@
package org.elasticsearch.test.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterModule;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleComponent;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ClusterConnectionManager;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.SendRequestTransportException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportMessageListener;
@ -52,12 +43,8 @@ import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.TransportStats;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@ -68,11 +55,8 @@ import static org.apache.lucene.util.LuceneTestCase.rarely;
/**
* A basic transport implementation that allows to intercept requests that have been sent
*/
public class MockTransport implements Transport, LifecycleComponent {
public class MockTransport extends StubbableTransport {
private volatile Map<String, RequestHandlerRegistry> requestHandlers = Collections.emptyMap();
private final Object requestHandlerMutex = new Object();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener;
private ConcurrentMap<Long, Tuple<DiscoveryNode, String>> requests = new ConcurrentHashMap<>();
@ -86,13 +70,18 @@ public class MockTransport implements Transport, LifecycleComponent {
connectionManager);
}
public MockTransport() {
super(new FakeTransport());
setDefaultConnectBehavior((transport, discoveryNode, profile, listener) -> listener.onResponse(createConnection(discoveryNode)));
}
/**
* simulate a response for the given requestId
*/
@SuppressWarnings("unchecked")
public <Response extends TransportResponse> void handleResponse(final long requestId, final Response response) {
final TransportResponseHandler<Response> transportResponseHandler =
(TransportResponseHandler<Response>) responseHandlers.onResponseReceived(requestId, listener);
(TransportResponseHandler<Response>) getResponseHandlers().onResponseReceived(requestId, listener);
if (transportResponseHandler != null) {
final Response deliveredResponse;
try (BytesStreamOutput output = new BytesStreamOutput()) {
@ -155,17 +144,12 @@ public class MockTransport implements Transport, LifecycleComponent {
* @param e the failure
*/
public void handleError(final long requestId, final TransportException e) {
final TransportResponseHandler transportResponseHandler = responseHandlers.onResponseReceived(requestId, listener);
final TransportResponseHandler transportResponseHandler = getResponseHandlers().onResponseReceived(requestId, listener);
if (transportResponseHandler != null) {
transportResponseHandler.handleException(e);
}
}
@Override
public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
listener.onResponse(createConnection(node));
}
public Connection createConnection(DiscoveryNode node) {
return new CloseableConnection() {
@Override
@ -185,83 +169,13 @@ public class MockTransport implements Transport, LifecycleComponent {
protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
}
@Override
public TransportStats getStats() {
throw new UnsupportedOperationException();
}
@Override
public BoundTransportAddress boundAddress() {
return null;
}
@Override
public Map<String, BoundTransportAddress> profileBoundAddresses() {
return null;
}
@Override
public TransportAddress[] addressesFromString(String address) {
return new TransportAddress[0];
}
@Override
public Lifecycle.State lifecycleState() {
return null;
}
@Override
public void addLifecycleListener(LifecycleListener listener) {
}
@Override
public void removeLifecycleListener(LifecycleListener listener) {
}
@Override
public void start() {
}
@Override
public void stop() {
}
@Override
public void close() {
}
@Override
public List<String> getDefaultSeedAddresses() {
return Collections.emptyList();
}
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
synchronized (requestHandlerMutex) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
}
@Override
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@SuppressWarnings("unchecked")
@Override
public RequestHandlerRegistry<TransportRequest> getRequestHandler(String action) {
return requestHandlers.get(action);
}
@Override
public void setMessageListener(TransportMessageListener listener) {
if (this.listener != null) {
throw new IllegalStateException("listener already set");
}
this.listener = listener;
super.setMessageListener(listener);
}
protected NamedWriteableRegistry writeableRegistry() {

View File

@ -415,6 +415,15 @@ public final class MockTransportService extends TransportService {
});
}
/**
* Adds a new handling behavior that is used when the defined request is received.
*
*/
public <R extends TransportRequest> void addRequestHandlingBehavior(String actionName,
StubbableTransport.RequestHandlingBehavior<R> handlingBehavior) {
transport().addRequestHandlingBehavior(actionName, handlingBehavior);
}
/**
* Adds a new send behavior that is used for communication with the given delegate service.
*

View File

@ -26,12 +26,15 @@ import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportStats;
@ -41,10 +44,11 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public final class StubbableTransport implements Transport {
public class StubbableTransport implements Transport {
private final ConcurrentHashMap<TransportAddress, SendRequestBehavior> sendBehaviors = new ConcurrentHashMap<>();
private final ConcurrentHashMap<TransportAddress, OpenConnectionBehavior> connectBehaviors = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, RequestHandlerRegistry<?>> replacedRequestRegistries = new ConcurrentHashMap<>();
private volatile SendRequestBehavior defaultSendRequest = null;
private volatile OpenConnectionBehavior defaultConnectBehavior = null;
private final Transport delegate;
@ -74,11 +78,28 @@ public final class StubbableTransport implements Transport {
return connectBehaviors.put(transportAddress, connectBehavior) == null;
}
<Request extends TransportRequest> void addRequestHandlingBehavior(String actionName, RequestHandlingBehavior<Request> behavior) {
final RequestHandlers requestHandlers = delegate.getRequestHandlers();
final RequestHandlerRegistry<Request> realRegistry = requestHandlers.getHandler(actionName);
if (realRegistry == null) {
throw new IllegalStateException("Cannot find registered action for: " + actionName);
}
replacedRequestRegistries.put(actionName, realRegistry);
final TransportRequestHandler<Request> realHandler = realRegistry.getHandler();
final RequestHandlerRegistry<Request> newRegistry = RequestHandlerRegistry.replaceHandler(realRegistry, (request, channel, task) ->
behavior.messageReceived(realHandler, request, channel, task));
requestHandlers.forceRegister(newRegistry);
}
void clearBehaviors() {
this.defaultSendRequest = null;
sendBehaviors.clear();
this.defaultConnectBehavior = null;
connectBehaviors.clear();
for (Map.Entry<String, RequestHandlerRegistry<?>> entry : replacedRequestRegistries.entrySet()) {
getRequestHandlers().forceRegister(entry.getValue());
}
replacedRequestRegistries.clear();
}
void clearBehavior(TransportAddress transportAddress) {
@ -101,16 +122,6 @@ public final class StubbableTransport implements Transport {
delegate.setMessageListener(listener);
}
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
delegate.registerRequestHandler(reg);
}
@Override
public RequestHandlerRegistry getRequestHandler(String action) {
return delegate.getRequestHandler(action);
}
@Override
public BoundTransportAddress boundAddress() {
return delegate.boundAddress();
@ -152,6 +163,11 @@ public final class StubbableTransport implements Transport {
return delegate.getResponseHandlers();
}
@Override
public RequestHandlers getRequestHandlers() {
return delegate.getRequestHandlers();
}
@Override
public Lifecycle.State lifecycleState() {
return delegate.lifecycleState();
@ -257,7 +273,15 @@ public final class StubbableTransport implements Transport {
void sendRequest(Connection connection, long requestId, String action, TransportRequest request,
TransportRequestOptions options) throws IOException;
default void clearCallback() {
}
default void clearCallback() {}
}
@FunctionalInterface
public interface RequestHandlingBehavior<Request extends TransportRequest> {
void messageReceived(TransportRequestHandler<Request> handler, Request request, TransportChannel channel, Task task)
throws Exception;
default void clearCallback() {}
}
}

View File

@ -54,7 +54,6 @@ import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.StatsTracker;
import org.elasticsearch.transport.TcpChannel;
@ -278,11 +277,11 @@ public class MockNioTransport extends TcpTransport {
this.channel = channel;
final ThreadPool threadPool = transport.getThreadPool();
final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker();
final InboundHandler inboundHandler = transport.getInboundHandler();
final RequestHandlers requestHandlers = transport.getRequestHandlers();
final Version version = transport.getVersion();
final StatsTracker statsTracker = transport.getStatsTracker();
this.pipeline = new InboundPipeline(version, statsTracker, recycler, threadPool::relativeTimeInMillis, breaker,
inboundHandler::getRequestHandler, transport::inboundMessage);
requestHandlers::getHandler, transport::inboundMessage);
}
@Override

View File

@ -22,7 +22,8 @@ public class SecurityServerTransportServiceTests extends SecurityIntegTestCase {
public void testSecurityServerTransportServiceWrapsAllHandlers() {
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
RequestHandlerRegistry handler = transportService.transport.getRequestHandler(TransportService.HANDSHAKE_ACTION_NAME);
RequestHandlerRegistry handler = transportService.transport.getRequestHandlers()
.getHandler(TransportService.HANDSHAKE_ACTION_NAME);
assertEquals(
"handler not wrapped by " + SecurityServerTransportInterceptor.ProfileSecuredRequestHandler.class +
"; do all the handler registration methods have overrides?",