`MockTcpTransport` to connect asynchronously (#28203)

The method `initiateChannel` on `TcpTransport` is explicit in that
channels can be connect asynchronously. All production implementations
do connect asynchronously. Only the blocking `MockTcpTransport`
connects in a synchronous manner. This avoids testing some of the
blocking code in `TcpTransport` that waits on connections to complete.
Additionally, it requires a more extensive method signature than
required for other transports.

This commit modifies the `MockTcpTransport` to make these connections
asynchronously on a different thread. Additionally, it simplifies that
`initiateChannel` method signature.
This commit is contained in:
Tim Brooks 2018-01-15 10:20:30 -07:00 committed by GitHub
parent 190f1e1fb3
commit ee7eac8dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 56 deletions

View File

@ -40,7 +40,6 @@ import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -51,12 +50,10 @@ import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions;
@ -239,9 +236,8 @@ public class Netty4Transport extends TcpTransport {
}
@Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> listener)
throws IOException {
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
protected NettyTcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> listener) throws IOException {
ChannelFuture channelFuture = bootstrap.connect(address);
Channel channel = channelFuture.channel();
if (channel == null) {
Netty4Utils.maybeDie(channelFuture.cause());

View File

@ -21,14 +21,12 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.EsExecutors;
@ -93,9 +91,8 @@ public class NioTransport extends TcpTransport {
}
@Override
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected TcpNioSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}

View File

@ -604,7 +604,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
try {
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture);
TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture);
TcpChannel channel = initiateChannel(node.getAddress().address(), connectFuture);
logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel));
channels.add(channel);
} catch (Exception e) {
@ -1057,17 +1057,14 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException;
/**
* Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout.
* It is provided for synchronous connection implementations.
* Initiate a single tcp socket channel.
*
* @param node the node
* @param connectTimeout the connection timeout
* @param connectListener listener to be called when connection complete
* @param address address for the initiated connection
* @param connectListener listener to be called when connection complete
* @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel
*/
protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException;
protected abstract TcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException;
/**
* Called to tear down internal resources

View File

@ -22,7 +22,6 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
@ -41,7 +40,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.io.StreamCorruptedException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
@ -49,7 +47,6 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
/** Unit tests for {@link TcpTransport} */
public class TcpTransportTests extends ESTestCase {
@ -193,8 +190,7 @@ public class TcpTransportTests extends ESTestCase {
}
@Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
protected FakeChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
return new FakeChannel(messageCaptor);
}

View File

@ -21,7 +21,6 @@ package org.elasticsearch.transport;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.InputStreamStreamInput;
@ -30,7 +29,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.CancellableThreads;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
@ -49,7 +47,6 @@ import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
@ -61,7 +58,6 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
/**
* This is a socket based blocking TcpTransport implementation that is used for tests
@ -164,28 +160,32 @@ public class MockTcpTransport extends TcpTransport {
}
@Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
InetSocketAddress address = node.getAddress().address();
protected MockChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
final MockSocket socket = new MockSocket();
final MockChannel channel = new MockChannel(socket, address, "none");
boolean success = false;
try {
configureSocket(socket);
try {
socket.connect(address, Math.toIntExact(connectTimeout.millis()));
} catch (SocketTimeoutException ex) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
}
MockChannel channel = new MockChannel(socket, address, "none", (c) -> {});
channel.loopRead(executor);
success = true;
connectListener.onResponse(null);
return channel;
} finally {
if (success == false) {
IOUtils.close(socket);
}
}
executor.submit(() -> {
try {
socket.connect(address);
channel.loopRead(executor);
connectListener.onResponse(null);
} catch (Exception ex) {
connectListener.onFailure(ex);
}
});
return channel;
}
@Override
@ -218,7 +218,6 @@ public class MockTcpTransport extends TcpTransport {
private final Socket activeChannel;
private final String profile;
private final CancellableThreads cancellableThreads = new CancellableThreads();
private final Closeable onClose;
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
/**
@ -227,14 +226,12 @@ public class MockTcpTransport extends TcpTransport {
* @param socket The client socket. Mut not be null.
* @param localAddress Address associated with the corresponding local server socket. Must not be null.
* @param profile The associated profile name.
* @param onClose Callback to execute when this channel is closed.
*/
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile, Consumer<MockChannel> onClose) {
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile) {
this.localAddress = localAddress;
this.activeChannel = socket;
this.serverSocket = null;
this.profile = profile;
this.onClose = () -> onClose.accept(this);
synchronized (openChannels) {
openChannels.add(this);
}
@ -246,12 +243,11 @@ public class MockTcpTransport extends TcpTransport {
* @param serverSocket The associated server socket. Must not be null.
* @param profile The associated profile name.
*/
public MockChannel(ServerSocket serverSocket, String profile) {
MockChannel(ServerSocket serverSocket, String profile) {
this.localAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress();
this.serverSocket = serverSocket;
this.profile = profile;
this.activeChannel = null;
this.onClose = null;
synchronized (openChannels) {
openChannels.add(this);
}
@ -266,8 +262,19 @@ public class MockTcpTransport extends TcpTransport {
synchronized (this) {
if (isOpen.get()) {
incomingChannel = new MockChannel(incomingSocket,
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile,
workerChannels::remove);
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile);
MockChannel finalIncomingChannel = incomingChannel;
incomingChannel.addCloseListener(new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
workerChannels.remove(finalIncomingChannel);
}
@Override
public void onFailure(Exception e) {
workerChannels.remove(finalIncomingChannel);
}
});
serverAcceptedChannel(incomingChannel);
//establish a happens-before edge between closing and accepting a new connection
workerChannels.add(incomingChannel);
@ -287,7 +294,7 @@ public class MockTcpTransport extends TcpTransport {
}
}
public void loopRead(Executor executor) {
void loopRead(Executor executor) {
executor.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
@ -312,7 +319,7 @@ public class MockTcpTransport extends TcpTransport {
});
}
public synchronized void close0() throws IOException {
synchronized void close0() throws IOException {
// establish a happens-before edge between closing and accepting a new connection
// we have to sync this entire block to ensure that our openChannels checks work correctly.
// The close block below will close all worker channels but if one of the worker channels runs into an exception
@ -325,7 +332,7 @@ public class MockTcpTransport extends TcpTransport {
removedChannel = openChannels.remove(this);
}
IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels),
() -> cancellableThreads.cancel("channel closed"), onClose);
() -> cancellableThreads.cancel("channel closed"));
assert removedChannel: "Channel was not removed or removed twice?";
}
}

View File

@ -21,13 +21,11 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@ -83,9 +81,8 @@ public class MockNioTransport extends TcpTransport {
}
@Override
protected MockSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
MockSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected MockSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
MockSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}