Introduce mechanism to stub request handling (#55832)

Currently there is a clear mechanism to stub sending a request through
the transport. However, this is limited to testing exceptions on the
sender side. This commit reworks our transport related testing
infrastructure to allow stubbing request handling on the receiving side.
This commit is contained in:
Tim Brooks 2020-04-27 16:57:15 -06:00 committed by GitHub
parent 2ff858b290
commit 80662f31a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 305 additions and 221 deletions

View File

@ -32,8 +32,8 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.Transports;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
@ -56,9 +56,9 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
Netty4MessageChannelHandler(PageCacheRecycler recycler, Netty4Transport transport) { Netty4MessageChannelHandler(PageCacheRecycler recycler, Netty4Transport transport) {
this.transport = transport; this.transport = transport;
final ThreadPool threadPool = transport.getThreadPool(); final ThreadPool threadPool = transport.getThreadPool();
final InboundHandler inboundHandler = transport.getInboundHandler(); final Transport.RequestHandlers requestHandlers = transport.getRequestHandlers();
this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis,
transport.getInflightBreaker(), inboundHandler::getRequestHandler, transport::inboundMessage); transport.getInflightBreaker(), requestHandlers::getHandler, transport::inboundMessage);
} }
@Override @Override

View File

@ -31,9 +31,9 @@ import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page; import org.elasticsearch.nio.Page;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import java.io.IOException; import java.io.IOException;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -47,9 +47,9 @@ public class TcpReadWriteHandler extends BytesWriteHandler {
this.channel = channel; this.channel = channel;
final ThreadPool threadPool = transport.getThreadPool(); final ThreadPool threadPool = transport.getThreadPool();
final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker(); final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker();
final InboundHandler inboundHandler = transport.getInboundHandler(); final Transport.RequestHandlers requestHandlers = transport.getRequestHandlers();
this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis,
breaker, inboundHandler::getRequestHandler, transport::inboundMessage); breaker, requestHandlers::getHandler, transport::inboundMessage);
} }
@Override @Override

View File

@ -23,7 +23,6 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -34,8 +33,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Map;
public class InboundHandler { public class InboundHandler {
@ -46,34 +43,21 @@ public class InboundHandler {
private final NamedWriteableRegistry namedWriteableRegistry; private final NamedWriteableRegistry namedWriteableRegistry;
private final TransportHandshaker handshaker; private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive; private final TransportKeepAlive keepAlive;
private final Transport.ResponseHandlers responseHandlers;
private final Transport.RequestHandlers requestHandlers;
private final Transport.ResponseHandlers responseHandlers = new Transport.ResponseHandlers();
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry, InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker, TransportKeepAlive keepAlive) { TransportHandshaker handshaker, TransportKeepAlive keepAlive, Transport.RequestHandlers requestHandlers,
Transport.ResponseHandlers responseHandlers) {
this.threadPool = threadPool; this.threadPool = threadPool;
this.outboundHandler = outboundHandler; this.outboundHandler = outboundHandler;
this.namedWriteableRegistry = namedWriteableRegistry; this.namedWriteableRegistry = namedWriteableRegistry;
this.handshaker = handshaker; this.handshaker = handshaker;
this.keepAlive = keepAlive; this.keepAlive = keepAlive;
} this.requestHandlers = requestHandlers;
this.responseHandlers = responseHandlers;
synchronized <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
@SuppressWarnings("unchecked")
public final <T extends TransportRequest> RequestHandlerRegistry<T> getRequestHandler(String action) {
return (RequestHandlerRegistry<T>) requestHandlers.get(action);
}
final Transport.ResponseHandlers getResponseHandlers() {
return responseHandlers;
} }
void setMessageListener(TransportMessageListener listener) { void setMessageListener(TransportMessageListener listener) {
@ -171,7 +155,7 @@ public class InboundHandler {
} else { } else {
final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(stream, header.getVersion()); assertRemoteVersion(stream, header.getVersion());
final RequestHandlerRegistry<T> reg = getRequestHandler(action); final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
assert reg != null; assert reg != null;
final T request = reg.newRequest(stream); final T request = reg.newRequest(stream);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));

View File

@ -81,9 +81,18 @@ public class RequestHandlerRegistry<Request extends TransportRequest> {
return executor; return executor;
} }
public TransportRequestHandler<Request> getHandler() {
return handler;
}
@Override @Override
public String toString() { public String toString() {
return handler.toString(); return handler.toString();
} }
public static <R extends TransportRequest> RequestHandlerRegistry<R> replaceHandler(RequestHandlerRegistry<R> registry,
TransportRequestHandler<R> handler) {
return new RequestHandlerRegistry<>(registry.action, registry.requestReader, registry.taskManager, handler,
registry.executor, registry.forceExecution, registry.canTripCircuitBreaker);
}
} }

View File

@ -129,7 +129,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private final TransportHandshaker handshaker; private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive; private final TransportKeepAlive keepAlive;
private final OutboundHandler outboundHandler; private final OutboundHandler outboundHandler;
protected final InboundHandler inboundHandler; private final InboundHandler inboundHandler;
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final RequestHandlers requestHandlers = new RequestHandlers();
public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry,
@ -163,7 +165,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true)); TransportRequestOptions.EMPTY, v, false, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive,
requestHandlers, responseHandlers);
} }
public Version getVersion() { public Version getVersion() {
@ -182,10 +185,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
return () -> circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); return () -> circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
} }
public InboundHandler getInboundHandler() {
return inboundHandler;
}
@Override @Override
protected void doStart() { protected void doStart() {
} }
@ -196,11 +195,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
inboundHandler.setMessageListener(listener); inboundHandler.setMessageListener(listener);
} }
@Override
public synchronized <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
inboundHandler.registerRequestHandler(reg);
}
public final class NodeChannels extends CloseableConnection { public final class NodeChannels extends CloseableConnection {
private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping; private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping;
private final List<TcpChannel> channels; private final List<TcpChannel> channels;
@ -813,7 +807,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
} }
public void executeHandshake(DiscoveryNode node, TcpChannel channel, ConnectionProfile profile, ActionListener<Version> listener) { public void executeHandshake(DiscoveryNode node, TcpChannel channel, ConnectionProfile profile, ActionListener<Version> listener) {
long requestId = inboundHandler.getResponseHandlers().newRequestId(); long requestId = responseHandlers.newRequestId();
handshaker.sendHandshake(requestId, node, channel, profile.getHandshakeTimeout(), listener); handshaker.sendHandshake(requestId, node, channel, profile.getHandshakeTimeout(), listener);
} }
@ -917,12 +911,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
@Override @Override
public final ResponseHandlers getResponseHandlers() { public final ResponseHandlers getResponseHandlers() {
return inboundHandler.getResponseHandlers(); return responseHandlers;
} }
@Override @Override
public final RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action) { public final RequestHandlers getRequestHandlers() {
return inboundHandler.getRequestHandler(action); return requestHandlers;
} }
private final class ChannelsConnectedListener implements ActionListener<Void> { private final class ChannelsConnectedListener implements ActionListener<Void> {

View File

@ -22,6 +22,7 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.component.LifecycleComponent;
import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
@ -32,6 +33,7 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -42,13 +44,9 @@ public interface Transport extends LifecycleComponent {
/** /**
* Registers a new request handler * Registers a new request handler
*/ */
<Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg); default <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
getRequestHandlers().registerHandler(reg);
/** }
* Returns the registered request handler registry for the given action or <code>null</code> if it's not registered
* @param action the action to look up
*/
RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action);
void setMessageListener(TransportMessageListener listener); void setMessageListener(TransportMessageListener listener);
@ -87,6 +85,8 @@ public interface Transport extends LifecycleComponent {
ResponseHandlers getResponseHandlers(); ResponseHandlers getResponseHandlers();
RequestHandlers getRequestHandlers();
/** /**
* A unidirectional connection to a {@link DiscoveryNode} * A unidirectional connection to a {@link DiscoveryNode}
*/ */
@ -187,7 +187,7 @@ public interface Transport extends LifecycleComponent {
* Removes and return the {@link ResponseContext} for the given request ID or returns * Removes and return the {@link ResponseContext} for the given request ID or returns
* <code>null</code> if no context is associated with this request ID. * <code>null</code> if no context is associated with this request ID.
*/ */
public ResponseContext remove(long requestId) { public ResponseContext<? extends TransportResponse> remove(long requestId) {
return handlers.remove(requestId); return handlers.remove(requestId);
} }
@ -198,7 +198,7 @@ public interface Transport extends LifecycleComponent {
*/ */
public long add(ResponseContext<? extends TransportResponse> holder) { public long add(ResponseContext<? extends TransportResponse> holder) {
long requestId = newRequestId(); long requestId = newRequestId();
ResponseContext existing = handlers.put(requestId, holder); ResponseContext<? extends TransportResponse> existing = handlers.put(requestId, holder);
assert existing == null : "request ID already in use: " + requestId; assert existing == null : "request ID already in use: " + requestId;
return requestId; return requestId;
} }
@ -214,12 +214,12 @@ public interface Transport extends LifecycleComponent {
/** /**
* Removes and returns all {@link ResponseContext} instances that match the predicate * Removes and returns all {@link ResponseContext} instances that match the predicate
*/ */
public List<ResponseContext<? extends TransportResponse>> prune(Predicate<ResponseContext> predicate) { public List<ResponseContext<? extends TransportResponse>> prune(Predicate<ResponseContext<? extends TransportResponse>> predicate) {
final List<ResponseContext<? extends TransportResponse>> holders = new ArrayList<>(); final List<ResponseContext<? extends TransportResponse>> holders = new ArrayList<>();
for (Map.Entry<Long, ResponseContext<? extends TransportResponse>> entry : handlers.entrySet()) { for (Map.Entry<Long, ResponseContext<? extends TransportResponse>> entry : handlers.entrySet()) {
ResponseContext<? extends TransportResponse> holder = entry.getValue(); ResponseContext<? extends TransportResponse> holder = entry.getValue();
if (predicate.test(holder)) { if (predicate.test(holder)) {
ResponseContext remove = handlers.remove(entry.getKey()); ResponseContext<? extends TransportResponse> remove = handlers.remove(entry.getKey());
if (remove != null) { if (remove != null) {
holders.add(holder); holders.add(holder);
} }
@ -244,4 +244,27 @@ public interface Transport extends LifecycleComponent {
} }
} }
} }
final class RequestHandlers {
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
synchronized <Request extends TransportRequest> void registerHandler(RequestHandlerRegistry<Request> reg) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
// TODO: Only visible for testing. Perhaps move StubbableTransport from
// org.elasticsearch.test.transport to org.elasticsearch.transport
public synchronized <Request extends TransportRequest> void forceRegister(RequestHandlerRegistry<Request> reg) {
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
@SuppressWarnings("unchecked")
public <T extends TransportRequest> RequestHandlerRegistry<T> getHandler(String action) {
return (RequestHandlerRegistry<T>) requestHandlers.get(action);
}
}
} }

View File

@ -740,7 +740,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
} catch (final Exception e) { } catch (final Exception e) {
// usually happen either because we failed to connect to the node // usually happen either because we failed to connect to the node
// or because we failed serializing the message // or because we failed serializing the message
final Transport.ResponseContext contextToNotify = responseHandlers.remove(requestId); final Transport.ResponseContext<? extends TransportResponse> contextToNotify = responseHandlers.remove(requestId);
// If holderToNotify == null then handler has already been taken care of. // If holderToNotify == null then handler has already been taken care of.
if (contextToNotify != null) { if (contextToNotify != null) {
if (timeoutHandler != null) { if (timeoutHandler != null) {
@ -986,7 +986,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
} }
public RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action) { public RequestHandlerRegistry<? extends TransportRequest> getRequestHandler(String action) {
return transport.getRequestHandler(action); return transport.getRequestHandlers().getHandler(action);
} }
private void checkForTimeout(long requestId) { private void checkForTimeout(long requestId) {
@ -1065,7 +1065,7 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
long timeoutTime = threadPool.relativeTimeInMillis(); long timeoutTime = threadPool.relativeTimeInMillis();
timeoutInfoHandlers.put(requestId, new TimeoutInfoHolder(node, action, sentTime, timeoutTime)); timeoutInfoHandlers.put(requestId, new TimeoutInfoHolder(node, action, sentTime, timeoutTime));
// now that we have the information visible via timeoutInfoHandlers, we try to remove the request id // now that we have the information visible via timeoutInfoHandlers, we try to remove the request id
final Transport.ResponseContext holder = responseHandlers.remove(requestId); final Transport.ResponseContext<? extends TransportResponse> holder = responseHandlers.remove(requestId);
if (holder != null) { if (holder != null) {
assert holder.action().equals(action); assert holder.action().equals(action);
assert holder.connection().getNode().equals(node); assert holder.connection().getNode().equals(node);

View File

@ -28,7 +28,6 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -37,7 +36,6 @@ import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.transport.CloseableConnection; import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener; import org.elasticsearch.transport.TransportMessageListener;
@ -60,7 +58,7 @@ abstract class FailAndRetryMockTransport<Response extends TransportResponse> imp
private final Random random; private final Random random;
private final ClusterName clusterName; private final ClusterName clusterName;
private volatile Map<String, RequestHandlerRegistry> requestHandlers = Collections.emptyMap(); private final RequestHandlers requestHandlers = new RequestHandlers();
private final Object requestHandlerMutex = new Object(); private final Object requestHandlerMutex = new Object();
private final ResponseHandlers responseHandlers = new ResponseHandlers(); private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener; private TransportMessageListener listener;
@ -205,26 +203,16 @@ abstract class FailAndRetryMockTransport<Response extends TransportResponse> imp
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
synchronized (requestHandlerMutex) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
}
@Override @Override
public ResponseHandlers getResponseHandlers() { public ResponseHandlers getResponseHandlers() {
return responseHandlers; return responseHandlers;
} }
@Override @Override
public RequestHandlerRegistry getRequestHandler(String action) { public RequestHandlers getRequestHandlers() {
return requestHandlers.get(action); return requestHandlers;
} }
@Override @Override
public void setMessageListener(TransportMessageListener listener) { public void setMessageListener(TransportMessageListener listener) {
this.listener = listener; this.listener = listener;

View File

@ -44,7 +44,6 @@ import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener; import org.elasticsearch.transport.TransportMessageListener;
@ -505,8 +504,9 @@ public class NodeConnectionsServiceTests extends ESTestCase {
} }
} }
private final class MockTransport implements Transport { private static final class MockTransport implements Transport {
private ResponseHandlers responseHandlers = new ResponseHandlers(); private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final RequestHandlers requestHandlers = new RequestHandlers();
private volatile boolean randomConnectionExceptions = false; private volatile boolean randomConnectionExceptions = false;
private final ThreadPool threadPool; private final ThreadPool threadPool;
@ -514,16 +514,6 @@ public class NodeConnectionsServiceTests extends ESTestCase {
this.threadPool = threadPool; this.threadPool = threadPool;
} }
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
}
@SuppressWarnings("unchecked")
@Override
public RequestHandlerRegistry getRequestHandler(String action) {
return null;
}
@Override @Override
public void setMessageListener(TransportMessageListener listener) { public void setMessageListener(TransportMessageListener listener) {
} }
@ -614,5 +604,10 @@ public class NodeConnectionsServiceTests extends ESTestCase {
public ResponseHandlers getResponseHandlers() { public ResponseHandlers getResponseHandlers() {
return responseHandlers; return responseHandlers;
} }
@Override
public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
} }
} }

View File

@ -229,8 +229,7 @@ public class NodeJoinTests extends ESTestCase {
// clone the node before submitting to simulate an incoming join, which is guaranteed to have a new // clone the node before submitting to simulate an incoming join, which is guaranteed to have a new
// disco node object serialized off the network // disco node object serialized off the network
try { try {
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>) final RequestHandlerRegistry<JoinRequest> joinHandler = transport.getRequestHandlers().getHandler(JoinHelper.JOIN_ACTION_NAME);
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() { final ActionListener<TransportResponse> listener = new ActionListener<TransportResponse>() {
@Override @Override
@ -434,8 +433,8 @@ public class NodeJoinTests extends ESTestCase {
} }
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception { private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>) final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = transport.getRequestHandlers()
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME); .getHandler(JoinHelper.START_JOIN_ACTION_NAME);
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel( startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(
new ActionListener<TransportResponse>() { new ActionListener<TransportResponse>() {
@Override @Override
@ -453,9 +452,8 @@ public class NodeJoinTests extends ESTestCase {
} }
private void handleFollowerCheckFrom(DiscoveryNode node, long term) throws Exception { private void handleFollowerCheckFrom(DiscoveryNode node, long term) throws Exception {
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler = final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler = transport.getRequestHandlers()
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>) .getHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() { final TestTransportChannel channel = new TestTransportChannel(new ActionListener<TransportResponse>() {
@Override @Override
public void onResponse(TransportResponse transportResponse) { public void onResponse(TransportResponse transportResponse) {

View File

@ -43,16 +43,17 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.ReplicationResponse;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.NodeConnectionsService;
import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.action.shard.ShardStateAction;
import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.RecoverySource; import org.elasticsearch.cluster.routing.RecoverySource;
import org.elasticsearch.cluster.routing.RecoverySource.PeerRecoverySource; import org.elasticsearch.cluster.routing.RecoverySource.PeerRecoverySource;
import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource; import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource;
import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.cluster.routing.allocation.command.AllocateEmptyPrimaryAllocationCommand; import org.elasticsearch.cluster.routing.allocation.command.AllocateEmptyPrimaryAllocationCommand;
import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand; import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand;
@ -136,6 +137,7 @@ import static org.elasticsearch.action.DocWriteResponse.Result.UPDATED;
import static org.elasticsearch.node.RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING; import static org.elasticsearch.node.RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -144,7 +146,6 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isOneOf; import static org.hamcrest.Matchers.isOneOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.empty;
@ClusterScope(scope = Scope.TEST, numDataNodes = 0) @ClusterScope(scope = Scope.TEST, numDataNodes = 0)
public class IndexRecoveryIT extends ESIntegTestCase { public class IndexRecoveryIT extends ESIntegTestCase {
@ -721,6 +722,7 @@ public class IndexRecoveryIT extends ESIntegTestCase {
final Settings nodeSettings = Settings.builder() final Settings nodeSettings = Settings.builder()
.put(RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING.getKey(), "100ms") .put(RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING.getKey(), "100ms")
.put(RecoverySettings.INDICES_RECOVERY_INTERNAL_ACTION_TIMEOUT_SETTING.getKey(), "1s") .put(RecoverySettings.INDICES_RECOVERY_INTERNAL_ACTION_TIMEOUT_SETTING.getKey(), "1s")
.put(NodeConnectionsService.CLUSTER_NODE_RECONNECT_INTERVAL_SETTING.getKey(), "1s")
.build(); .build();
// start a master node // start a master node
internalCluster().startNode(nodeSettings); internalCluster().startNode(nodeSettings);
@ -777,12 +779,29 @@ public class IndexRecoveryIT extends ESIntegTestCase {
(MockTransportService) internalCluster().getInstance(TransportService.class, redNodeName); (MockTransportService) internalCluster().getInstance(TransportService.class, redNodeName);
TransportService redTransportService = internalCluster().getInstance(TransportService.class, redNodeName); TransportService redTransportService = internalCluster().getInstance(TransportService.class, redNodeName);
TransportService blueTransportService = internalCluster().getInstance(TransportService.class, blueNodeName); TransportService blueTransportService = internalCluster().getInstance(TransportService.class, blueNodeName);
final CountDownLatch requestBlocked = new CountDownLatch(1); final CountDownLatch requestFailed = new CountDownLatch(1);
blueMockTransportService.addSendBehavior(redTransportService, if (randomBoolean()) {
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestBlocked)); // Fail on the sending side
redMockTransportService.addSendBehavior(blueTransportService, blueMockTransportService.addSendBehavior(redTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestBlocked)); new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestFailed));
redMockTransportService.addSendBehavior(blueTransportService,
new RecoveryActionBlocker(dropRequests, recoveryActionToBlock, requestFailed));
} else {
// Fail on the receiving side.
blueMockTransportService.addRequestHandlingBehavior(recoveryActionToBlock, (handler, request, channel, task) -> {
logger.info("--> preventing {} response by closing response channel", recoveryActionToBlock);
requestFailed.countDown();
redMockTransportService.disconnectFromNode(blueMockTransportService.getLocalDiscoNode());
handler.messageReceived(request, channel, task);
});
redMockTransportService.addRequestHandlingBehavior(recoveryActionToBlock, (handler, request, channel, task) -> {
logger.info("--> preventing {} response by closing response channel", recoveryActionToBlock);
requestFailed.countDown();
blueMockTransportService.disconnectFromNode(redMockTransportService.getLocalDiscoNode());
handler.messageReceived(request, channel, task);
});
}
logger.info("--> starting recovery from blue to red"); logger.info("--> starting recovery from blue to red");
client().admin().indices().prepareUpdateSettings(indexName).setSettings( client().admin().indices().prepareUpdateSettings(indexName).setSettings(
@ -791,9 +810,9 @@ public class IndexRecoveryIT extends ESIntegTestCase {
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1)
).get(); ).get();
requestBlocked.await(); requestFailed.await();
logger.info("--> stopping to block recovery"); logger.info("--> clearing rules to allow recovery to proceed");
blueMockTransportService.clearAllRules(); blueMockTransportService.clearAllRules();
redMockTransportService.clearAllRules(); redMockTransportService.clearAllRules();
@ -818,12 +837,14 @@ public class IndexRecoveryIT extends ESIntegTestCase {
public void sendRequest(Transport.Connection connection, long requestId, String action, TransportRequest request, public void sendRequest(Transport.Connection connection, long requestId, String action, TransportRequest request,
TransportRequestOptions options) throws IOException { TransportRequestOptions options) throws IOException {
if (recoveryActionToBlock.equals(action) || requestBlocked.getCount() == 0) { if (recoveryActionToBlock.equals(action) || requestBlocked.getCount() == 0) {
logger.info("--> preventing {} request", action);
requestBlocked.countDown(); requestBlocked.countDown();
if (dropRequests) { if (dropRequests) {
logger.info("--> preventing {} request by dropping request", action);
return; return;
} else {
logger.info("--> preventing {} request by throwing ConnectTransportException", action);
throw new ConnectTransportException(connection.getNode(), "DISCONNECT: prevented " + action + " request");
} }
throw new ConnectTransportException(connection.getNode(), "DISCONNECT: prevented " + action + " request");
} }
connection.sendRequest(requestId, action, request, options); connection.sendRequest(requestId, action, request, options);
} }

View File

@ -48,6 +48,8 @@ public class InboundHandlerTests extends ESTestCase {
private final Version version = Version.CURRENT; private final Version version = Version.CURRENT;
private TaskManager taskManager; private TaskManager taskManager;
private Transport.ResponseHandlers responseHandlers;
private Transport.RequestHandlers requestHandlers;
private InboundHandler handler; private InboundHandler handler;
private FakeTcpChannel channel; private FakeTcpChannel channel;
@ -61,7 +63,10 @@ public class InboundHandlerTests extends ESTestCase {
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage); TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool, OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
BigArrays.NON_RECYCLING_INSTANCE); BigArrays.NON_RECYCLING_INSTANCE);
handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); requestHandlers = new Transport.RequestHandlers();
responseHandlers = new Transport.ResponseHandlers();
handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive, requestHandlers,
responseHandlers);
} }
@After @After
@ -74,7 +79,7 @@ public class InboundHandlerTests extends ESTestCase {
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>(); AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>("test-request", TestRequest::new, taskManager, RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>("test-request", TestRequest::new, taskManager,
(request, channel, task) -> channelCaptor.set(channel), ThreadPool.Names.SAME, false, true); (request, channel, task) -> channelCaptor.set(channel), ThreadPool.Names.SAME, false, true);
handler.registerRequestHandler(registry); requestHandlers.registerHandler(registry);
handler.inboundMessage(channel, new InboundMessage(null, true)); handler.inboundMessage(channel, new InboundMessage(null, true));
if (channel.isServerChannel()) { if (channel.isServerChannel()) {
@ -93,7 +98,7 @@ public class InboundHandlerTests extends ESTestCase {
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>(); AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>(); AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
long requestId = handler.getResponseHandlers().add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() { long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override @Override
public void handleResponse(TestResponse response) { public void handleResponse(TestResponse response) {
responseCaptor.set(response); responseCaptor.set(response);
@ -119,7 +124,7 @@ public class InboundHandlerTests extends ESTestCase {
channelCaptor.set(channel); channelCaptor.set(channel);
requestCaptor.set(request); requestCaptor.set(request);
}, ThreadPool.Names.SAME, false, true); }, ThreadPool.Names.SAME, false, true);
handler.registerRequestHandler(registry); requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10); String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(threadPool.getThreadContext(), new String[0], OutboundMessage.Request request = new OutboundMessage.Request(threadPool.getThreadContext(), new String[0],
new TestRequest(requestValue), version, action, requestId, false, false); new TestRequest(requestValue), version, action, requestId, false, false);

View File

@ -168,7 +168,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
protected void onConnectedDuringSend(long requestId, String action, TransportRequest request, protected void onConnectedDuringSend(long requestId, String action, TransportRequest request,
DisruptableMockTransport destinationTransport) { DisruptableMockTransport destinationTransport) {
final RequestHandlerRegistry<TransportRequest> requestHandler = final RequestHandlerRegistry<TransportRequest> requestHandler =
destinationTransport.getRequestHandler(action); destinationTransport.getRequestHandlers().getHandler(action);
final DiscoveryNode destination = destinationTransport.getLocalNode(); final DiscoveryNode destination = destinationTransport.getLocalNode();

View File

@ -0,0 +1,120 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.test.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessageListener;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportStats;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* A transport that does nothing. Normally wrapped by {@link StubbableTransport}.
*/
public class FakeTransport extends AbstractLifecycleComponent implements Transport {
private final RequestHandlers requestHandlers = new RequestHandlers();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener;
@Override
public void setMessageListener(TransportMessageListener listener) {
if (this.listener != null) {
throw new IllegalStateException("listener already set");
}
this.listener = listener;
}
@Override
public BoundTransportAddress boundAddress() {
return null;
}
@Override
public Map<String, BoundTransportAddress> profileBoundAddresses() {
return null;
}
@Override
public TransportAddress[] addressesFromString(String address) {
return new TransportAddress[0];
}
@Override
public List<String> getDefaultSeedAddresses() {
return Collections.emptyList();
}
@Override
public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
listener.onResponse(new CloseableConnection() {
@Override
public DiscoveryNode getNode() {
return node;
}
@Override
public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) {
}
});
}
@Override
public TransportStats getStats() {
return null;
}
@Override
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@Override
public RequestHandlers getRequestHandlers() {
return requestHandlers;
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
@Override
protected void doClose() {
}
}

View File

@ -19,31 +19,22 @@
package org.elasticsearch.test.transport; package org.elasticsearch.test.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterModule;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleComponent;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.CloseableConnection; import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ClusterConnectionManager; import org.elasticsearch.transport.ClusterConnectionManager;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.SendRequestTransportException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportMessageListener; import org.elasticsearch.transport.TransportMessageListener;
@ -52,12 +43,8 @@ import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.TransportStats;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
@ -68,11 +55,8 @@ import static org.apache.lucene.util.LuceneTestCase.rarely;
/** /**
* A basic transport implementation that allows to intercept requests that have been sent * A basic transport implementation that allows to intercept requests that have been sent
*/ */
public class MockTransport implements Transport, LifecycleComponent { public class MockTransport extends StubbableTransport {
private volatile Map<String, RequestHandlerRegistry> requestHandlers = Collections.emptyMap();
private final Object requestHandlerMutex = new Object();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private TransportMessageListener listener; private TransportMessageListener listener;
private ConcurrentMap<Long, Tuple<DiscoveryNode, String>> requests = new ConcurrentHashMap<>(); private ConcurrentMap<Long, Tuple<DiscoveryNode, String>> requests = new ConcurrentHashMap<>();
@ -86,13 +70,18 @@ public class MockTransport implements Transport, LifecycleComponent {
connectionManager); connectionManager);
} }
public MockTransport() {
super(new FakeTransport());
setDefaultConnectBehavior((transport, discoveryNode, profile, listener) -> listener.onResponse(createConnection(discoveryNode)));
}
/** /**
* simulate a response for the given requestId * simulate a response for the given requestId
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <Response extends TransportResponse> void handleResponse(final long requestId, final Response response) { public <Response extends TransportResponse> void handleResponse(final long requestId, final Response response) {
final TransportResponseHandler<Response> transportResponseHandler = final TransportResponseHandler<Response> transportResponseHandler =
(TransportResponseHandler<Response>) responseHandlers.onResponseReceived(requestId, listener); (TransportResponseHandler<Response>) getResponseHandlers().onResponseReceived(requestId, listener);
if (transportResponseHandler != null) { if (transportResponseHandler != null) {
final Response deliveredResponse; final Response deliveredResponse;
try (BytesStreamOutput output = new BytesStreamOutput()) { try (BytesStreamOutput output = new BytesStreamOutput()) {
@ -155,17 +144,12 @@ public class MockTransport implements Transport, LifecycleComponent {
* @param e the failure * @param e the failure
*/ */
public void handleError(final long requestId, final TransportException e) { public void handleError(final long requestId, final TransportException e) {
final TransportResponseHandler transportResponseHandler = responseHandlers.onResponseReceived(requestId, listener); final TransportResponseHandler transportResponseHandler = getResponseHandlers().onResponseReceived(requestId, listener);
if (transportResponseHandler != null) { if (transportResponseHandler != null) {
transportResponseHandler.handleException(e); transportResponseHandler.handleException(e);
} }
} }
@Override
public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
listener.onResponse(createConnection(node));
}
public Connection createConnection(DiscoveryNode node) { public Connection createConnection(DiscoveryNode node) {
return new CloseableConnection() { return new CloseableConnection() {
@Override @Override
@ -185,83 +169,13 @@ public class MockTransport implements Transport, LifecycleComponent {
protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) { protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
} }
@Override
public TransportStats getStats() {
throw new UnsupportedOperationException();
}
@Override
public BoundTransportAddress boundAddress() {
return null;
}
@Override
public Map<String, BoundTransportAddress> profileBoundAddresses() {
return null;
}
@Override
public TransportAddress[] addressesFromString(String address) {
return new TransportAddress[0];
}
@Override
public Lifecycle.State lifecycleState() {
return null;
}
@Override
public void addLifecycleListener(LifecycleListener listener) {
}
@Override
public void removeLifecycleListener(LifecycleListener listener) {
}
@Override
public void start() {
}
@Override
public void stop() {
}
@Override
public void close() {
}
@Override
public List<String> getDefaultSeedAddresses() {
return Collections.emptyList();
}
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
synchronized (requestHandlerMutex) {
if (requestHandlers.containsKey(reg.getAction())) {
throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
}
requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
}
}
@Override
public ResponseHandlers getResponseHandlers() {
return responseHandlers;
}
@SuppressWarnings("unchecked")
@Override
public RequestHandlerRegistry<TransportRequest> getRequestHandler(String action) {
return requestHandlers.get(action);
}
@Override @Override
public void setMessageListener(TransportMessageListener listener) { public void setMessageListener(TransportMessageListener listener) {
if (this.listener != null) { if (this.listener != null) {
throw new IllegalStateException("listener already set"); throw new IllegalStateException("listener already set");
} }
this.listener = listener; this.listener = listener;
super.setMessageListener(listener);
} }
protected NamedWriteableRegistry writeableRegistry() { protected NamedWriteableRegistry writeableRegistry() {

View File

@ -415,6 +415,15 @@ public final class MockTransportService extends TransportService {
}); });
} }
/**
* Adds a new handling behavior that is used when the defined request is received.
*
*/
public <R extends TransportRequest> void addRequestHandlingBehavior(String actionName,
StubbableTransport.RequestHandlingBehavior<R> handlingBehavior) {
transport().addRequestHandlingBehavior(actionName, handlingBehavior);
}
/** /**
* Adds a new send behavior that is used for communication with the given delegate service. * Adds a new send behavior that is used for communication with the given delegate service.
* *

View File

@ -26,12 +26,15 @@ import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry; import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportMessageListener; import org.elasticsearch.transport.TransportMessageListener;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportStats; import org.elasticsearch.transport.TransportStats;
@ -41,10 +44,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
public final class StubbableTransport implements Transport { public class StubbableTransport implements Transport {
private final ConcurrentHashMap<TransportAddress, SendRequestBehavior> sendBehaviors = new ConcurrentHashMap<>(); private final ConcurrentHashMap<TransportAddress, SendRequestBehavior> sendBehaviors = new ConcurrentHashMap<>();
private final ConcurrentHashMap<TransportAddress, OpenConnectionBehavior> connectBehaviors = new ConcurrentHashMap<>(); private final ConcurrentHashMap<TransportAddress, OpenConnectionBehavior> connectBehaviors = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, RequestHandlerRegistry<?>> replacedRequestRegistries = new ConcurrentHashMap<>();
private volatile SendRequestBehavior defaultSendRequest = null; private volatile SendRequestBehavior defaultSendRequest = null;
private volatile OpenConnectionBehavior defaultConnectBehavior = null; private volatile OpenConnectionBehavior defaultConnectBehavior = null;
private final Transport delegate; private final Transport delegate;
@ -74,11 +78,28 @@ public final class StubbableTransport implements Transport {
return connectBehaviors.put(transportAddress, connectBehavior) == null; return connectBehaviors.put(transportAddress, connectBehavior) == null;
} }
<Request extends TransportRequest> void addRequestHandlingBehavior(String actionName, RequestHandlingBehavior<Request> behavior) {
final RequestHandlers requestHandlers = delegate.getRequestHandlers();
final RequestHandlerRegistry<Request> realRegistry = requestHandlers.getHandler(actionName);
if (realRegistry == null) {
throw new IllegalStateException("Cannot find registered action for: " + actionName);
}
replacedRequestRegistries.put(actionName, realRegistry);
final TransportRequestHandler<Request> realHandler = realRegistry.getHandler();
final RequestHandlerRegistry<Request> newRegistry = RequestHandlerRegistry.replaceHandler(realRegistry, (request, channel, task) ->
behavior.messageReceived(realHandler, request, channel, task));
requestHandlers.forceRegister(newRegistry);
}
void clearBehaviors() { void clearBehaviors() {
this.defaultSendRequest = null; this.defaultSendRequest = null;
sendBehaviors.clear(); sendBehaviors.clear();
this.defaultConnectBehavior = null; this.defaultConnectBehavior = null;
connectBehaviors.clear(); connectBehaviors.clear();
for (Map.Entry<String, RequestHandlerRegistry<?>> entry : replacedRequestRegistries.entrySet()) {
getRequestHandlers().forceRegister(entry.getValue());
}
replacedRequestRegistries.clear();
} }
void clearBehavior(TransportAddress transportAddress) { void clearBehavior(TransportAddress transportAddress) {
@ -101,16 +122,6 @@ public final class StubbableTransport implements Transport {
delegate.setMessageListener(listener); delegate.setMessageListener(listener);
} }
@Override
public <Request extends TransportRequest> void registerRequestHandler(RequestHandlerRegistry<Request> reg) {
delegate.registerRequestHandler(reg);
}
@Override
public RequestHandlerRegistry getRequestHandler(String action) {
return delegate.getRequestHandler(action);
}
@Override @Override
public BoundTransportAddress boundAddress() { public BoundTransportAddress boundAddress() {
return delegate.boundAddress(); return delegate.boundAddress();
@ -152,6 +163,11 @@ public final class StubbableTransport implements Transport {
return delegate.getResponseHandlers(); return delegate.getResponseHandlers();
} }
@Override
public RequestHandlers getRequestHandlers() {
return delegate.getRequestHandlers();
}
@Override @Override
public Lifecycle.State lifecycleState() { public Lifecycle.State lifecycleState() {
return delegate.lifecycleState(); return delegate.lifecycleState();
@ -257,7 +273,15 @@ public final class StubbableTransport implements Transport {
void sendRequest(Connection connection, long requestId, String action, TransportRequest request, void sendRequest(Connection connection, long requestId, String action, TransportRequest request,
TransportRequestOptions options) throws IOException; TransportRequestOptions options) throws IOException;
default void clearCallback() { default void clearCallback() {}
} }
@FunctionalInterface
public interface RequestHandlingBehavior<Request extends TransportRequest> {
void messageReceived(TransportRequestHandler<Request> handler, Request request, TransportChannel channel, Task task)
throws Exception;
default void clearCallback() {}
} }
} }

View File

@ -54,7 +54,6 @@ import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.InboundHandler;
import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.InboundPipeline;
import org.elasticsearch.transport.StatsTracker; import org.elasticsearch.transport.StatsTracker;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
@ -278,11 +277,11 @@ public class MockNioTransport extends TcpTransport {
this.channel = channel; this.channel = channel;
final ThreadPool threadPool = transport.getThreadPool(); final ThreadPool threadPool = transport.getThreadPool();
final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker(); final Supplier<CircuitBreaker> breaker = transport.getInflightBreaker();
final InboundHandler inboundHandler = transport.getInboundHandler(); final RequestHandlers requestHandlers = transport.getRequestHandlers();
final Version version = transport.getVersion(); final Version version = transport.getVersion();
final StatsTracker statsTracker = transport.getStatsTracker(); final StatsTracker statsTracker = transport.getStatsTracker();
this.pipeline = new InboundPipeline(version, statsTracker, recycler, threadPool::relativeTimeInMillis, breaker, this.pipeline = new InboundPipeline(version, statsTracker, recycler, threadPool::relativeTimeInMillis, breaker,
inboundHandler::getRequestHandler, transport::inboundMessage); requestHandlers::getHandler, transport::inboundMessage);
} }
@Override @Override

View File

@ -22,7 +22,8 @@ public class SecurityServerTransportServiceTests extends SecurityIntegTestCase {
public void testSecurityServerTransportServiceWrapsAllHandlers() { public void testSecurityServerTransportServiceWrapsAllHandlers() {
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) { for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
RequestHandlerRegistry handler = transportService.transport.getRequestHandler(TransportService.HANDSHAKE_ACTION_NAME); RequestHandlerRegistry handler = transportService.transport.getRequestHandlers()
.getHandler(TransportService.HANDSHAKE_ACTION_NAME);
assertEquals( assertEquals(
"handler not wrapped by " + SecurityServerTransportInterceptor.ProfileSecuredRequestHandler.class + "handler not wrapped by " + SecurityServerTransportInterceptor.ProfileSecuredRequestHandler.class +
"; do all the handler registration methods have overrides?", "; do all the handler registration methods have overrides?",