MockTcpTransport to connect asynchronously (#28203)

The method `initiateChannel` on `TcpTransport` is explicit in that
channels can be connect asynchronously. All production implementations
do connect asynchronously. Only the blocking `MockTcpTransport`
connects in a synchronous manner. This avoids testing some of the
blocking code in `TcpTransport` that waits on connections to complete.
Additionally, it requires a more extensive method signature than
required for other transports.

This commit modifies the `MockTcpTransport` to make these connections
asynchronously on a different thread. Additionally, it simplifies that
`initiateChannel` method signature.
This commit is contained in:
Tim Brooks 2018-01-15 10:20:30 -07:00 committed by GitHub
parent 190f1e1fb3
commit ee7eac8dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 56 deletions

View File

@ -40,7 +40,6 @@ import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -51,12 +50,10 @@ import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportRequestOptions;
@ -239,9 +236,8 @@ public class Netty4Transport extends TcpTransport {
} }
@Override @Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> listener) protected NettyTcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> listener) throws IOException {
throws IOException { ChannelFuture channelFuture = bootstrap.connect(address);
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
Channel channel = channelFuture.channel(); Channel channel = channelFuture.channel();
if (channel == null) { if (channel == null) {
Netty4Utils.maybeDie(channelFuture.cause()); Netty4Utils.maybeDie(channelFuture.cause());

View File

@ -21,14 +21,12 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
@ -93,9 +91,8 @@ public class NioTransport extends TcpTransport {
} }
@Override @Override
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener) protected TcpNioSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
throws IOException { TcpNioSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
TcpNioSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener)); channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel; return channel;
} }

View File

@ -604,7 +604,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
try { try {
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture); connectionFutures.add(connectFuture);
TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture); TcpChannel channel = initiateChannel(node.getAddress().address(), connectFuture);
logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel)); logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel));
channels.add(channel); channels.add(channel);
} catch (Exception e) { } catch (Exception e) {
@ -1057,17 +1057,14 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException; protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException;
/** /**
* Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout. * Initiate a single tcp socket channel.
* It is provided for synchronous connection implementations.
* *
* @param node the node * @param address address for the initiated connection
* @param connectTimeout the connection timeout * @param connectListener listener to be called when connection complete
* @param connectListener listener to be called when connection complete
* @return the pending connection * @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel * @throws IOException if an I/O exception occurs while opening the channel
*/ */
protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener) protected abstract TcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException;
throws IOException;
/** /**
* Called to tear down internal resources * Called to tear down internal resources

View File

@ -22,7 +22,6 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory; import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.BytesStreamOutput;
@ -41,7 +40,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException; import java.io.IOException;
import java.io.StreamCorruptedException; import java.io.StreamCorruptedException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -49,7 +47,6 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
/** Unit tests for {@link TcpTransport} */ /** Unit tests for {@link TcpTransport} */
public class TcpTransportTests extends ESTestCase { public class TcpTransportTests extends ESTestCase {
@ -193,8 +190,7 @@ public class TcpTransportTests extends ESTestCase {
} }
@Override @Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener) protected FakeChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
throws IOException {
return new FakeChannel(messageCaptor); return new FakeChannel(messageCaptor);
} }

View File

@ -21,7 +21,6 @@ package org.elasticsearch.transport;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.InputStreamStreamInput; import org.elasticsearch.common.io.stream.InputStreamStreamInput;
@ -30,7 +29,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.common.util.CancellableThreads;
import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AbstractRunnable;
@ -49,7 +47,6 @@ import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketException; import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
@ -61,7 +58,6 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
/** /**
* This is a socket based blocking TcpTransport implementation that is used for tests * This is a socket based blocking TcpTransport implementation that is used for tests
@ -164,28 +160,32 @@ public class MockTcpTransport extends TcpTransport {
} }
@Override @Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener) protected MockChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
throws IOException {
InetSocketAddress address = node.getAddress().address();
final MockSocket socket = new MockSocket(); final MockSocket socket = new MockSocket();
final MockChannel channel = new MockChannel(socket, address, "none");
boolean success = false; boolean success = false;
try { try {
configureSocket(socket); configureSocket(socket);
try {
socket.connect(address, Math.toIntExact(connectTimeout.millis()));
} catch (SocketTimeoutException ex) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
}
MockChannel channel = new MockChannel(socket, address, "none", (c) -> {});
channel.loopRead(executor);
success = true; success = true;
connectListener.onResponse(null);
return channel;
} finally { } finally {
if (success == false) { if (success == false) {
IOUtils.close(socket); IOUtils.close(socket);
} }
} }
executor.submit(() -> {
try {
socket.connect(address);
channel.loopRead(executor);
connectListener.onResponse(null);
} catch (Exception ex) {
connectListener.onFailure(ex);
}
});
return channel;
} }
@Override @Override
@ -218,7 +218,6 @@ public class MockTcpTransport extends TcpTransport {
private final Socket activeChannel; private final Socket activeChannel;
private final String profile; private final String profile;
private final CancellableThreads cancellableThreads = new CancellableThreads(); private final CancellableThreads cancellableThreads = new CancellableThreads();
private final Closeable onClose;
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>(); private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
/** /**
@ -227,14 +226,12 @@ public class MockTcpTransport extends TcpTransport {
* @param socket The client socket. Mut not be null. * @param socket The client socket. Mut not be null.
* @param localAddress Address associated with the corresponding local server socket. Must not be null. * @param localAddress Address associated with the corresponding local server socket. Must not be null.
* @param profile The associated profile name. * @param profile The associated profile name.
* @param onClose Callback to execute when this channel is closed.
*/ */
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile, Consumer<MockChannel> onClose) { public MockChannel(Socket socket, InetSocketAddress localAddress, String profile) {
this.localAddress = localAddress; this.localAddress = localAddress;
this.activeChannel = socket; this.activeChannel = socket;
this.serverSocket = null; this.serverSocket = null;
this.profile = profile; this.profile = profile;
this.onClose = () -> onClose.accept(this);
synchronized (openChannels) { synchronized (openChannels) {
openChannels.add(this); openChannels.add(this);
} }
@ -246,12 +243,11 @@ public class MockTcpTransport extends TcpTransport {
* @param serverSocket The associated server socket. Must not be null. * @param serverSocket The associated server socket. Must not be null.
* @param profile The associated profile name. * @param profile The associated profile name.
*/ */
public MockChannel(ServerSocket serverSocket, String profile) { MockChannel(ServerSocket serverSocket, String profile) {
this.localAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress(); this.localAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress();
this.serverSocket = serverSocket; this.serverSocket = serverSocket;
this.profile = profile; this.profile = profile;
this.activeChannel = null; this.activeChannel = null;
this.onClose = null;
synchronized (openChannels) { synchronized (openChannels) {
openChannels.add(this); openChannels.add(this);
} }
@ -266,8 +262,19 @@ public class MockTcpTransport extends TcpTransport {
synchronized (this) { synchronized (this) {
if (isOpen.get()) { if (isOpen.get()) {
incomingChannel = new MockChannel(incomingSocket, incomingChannel = new MockChannel(incomingSocket,
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile, new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile);
workerChannels::remove); MockChannel finalIncomingChannel = incomingChannel;
incomingChannel.addCloseListener(new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
workerChannels.remove(finalIncomingChannel);
}
@Override
public void onFailure(Exception e) {
workerChannels.remove(finalIncomingChannel);
}
});
serverAcceptedChannel(incomingChannel); serverAcceptedChannel(incomingChannel);
//establish a happens-before edge between closing and accepting a new connection //establish a happens-before edge between closing and accepting a new connection
workerChannels.add(incomingChannel); workerChannels.add(incomingChannel);
@ -287,7 +294,7 @@ public class MockTcpTransport extends TcpTransport {
} }
} }
public void loopRead(Executor executor) { void loopRead(Executor executor) {
executor.execute(new AbstractRunnable() { executor.execute(new AbstractRunnable() {
@Override @Override
public void onFailure(Exception e) { public void onFailure(Exception e) {
@ -312,7 +319,7 @@ public class MockTcpTransport extends TcpTransport {
}); });
} }
public synchronized void close0() throws IOException { synchronized void close0() throws IOException {
// establish a happens-before edge between closing and accepting a new connection // establish a happens-before edge between closing and accepting a new connection
// we have to sync this entire block to ensure that our openChannels checks work correctly. // we have to sync this entire block to ensure that our openChannels checks work correctly.
// The close block below will close all worker channels but if one of the worker channels runs into an exception // The close block below will close all worker channels but if one of the worker channels runs into an exception
@ -325,7 +332,7 @@ public class MockTcpTransport extends TcpTransport {
removedChannel = openChannels.remove(this); removedChannel = openChannels.remove(this);
} }
IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels), IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels),
() -> cancellableThreads.cancel("channel closed"), onClose); () -> cancellableThreads.cancel("channel closed"));
assert removedChannel: "Channel was not removed or removed twice?"; assert removedChannel: "Channel was not removed or removed twice?";
} }
} }

View File

@ -21,13 +21,11 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
@ -83,9 +81,8 @@ public class MockNioTransport extends TcpTransport {
} }
@Override @Override
protected MockSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener) protected MockSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
throws IOException { MockSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
MockSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener)); channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel; return channel;
} }