diff --git a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java index b49fc72c48a..f6aa1e8445b 100644 --- a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java +++ b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.transport; +import org.elasticsearch.common.unit.TimeValue; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -42,14 +44,16 @@ public final class ConnectionProfile { TransportRequestOptions.Type.PING, TransportRequestOptions.Type.RECOVERY, TransportRequestOptions.Type.REG, - TransportRequestOptions.Type.STATE))), 1); + TransportRequestOptions.Type.STATE))), 1, null); private final List handles; private final int numConnections; + private final TimeValue connectTimeout; - private ConnectionProfile(List handles, int numConnections) { + private ConnectionProfile(List handles, int numConnections, TimeValue connectTimeout) { this.handles = handles; this.numConnections = numConnections; + this.connectTimeout = connectTimeout; } /** @@ -59,6 +63,17 @@ public final class ConnectionProfile { private final List handles = new ArrayList<>(); private final Set addedTypes = EnumSet.noneOf(TransportRequestOptions.Type.class); private int offset = 0; + private TimeValue connectTimeout; + + /** + * Sets a connect connectTimeout for this connection profile + */ + public void setConnectTimeout(TimeValue connectTimeout) { + if (connectTimeout.millis() < 0) { + throw new IllegalArgumentException("connectTimeout must be non-negative but was: " + connectTimeout); + } + this.connectTimeout = connectTimeout; + } /** * Adds a number of connections for one or more types. Each type can only be added once. @@ -89,8 +104,16 @@ public final class ConnectionProfile { if (types.isEmpty() == false) { throw new IllegalStateException("not all types are added for this connection profile - missing types: " + types); } - return new ConnectionProfile(Collections.unmodifiableList(handles), offset); + return new ConnectionProfile(Collections.unmodifiableList(handles), offset, connectTimeout); } + + } + + /** + * Returns the connect timeout or null if no explicit timeout is set on this profile. + */ + public TimeValue getConnectTimeout() { + return connectTimeout; } /** diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 3aa31f3c213..6f71e6b3d49 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -150,7 +150,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9); private static final int PING_DATA_SIZE = -1; - protected final TimeValue connectTimeout; protected final boolean blockingClient; private final CircuitBreakerService circuitBreakerService; // package visibility for tests @@ -190,9 +189,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i this.compress = Transport.TRANSPORT_TCP_COMPRESS.get(settings); this.networkService = networkService; this.transportName = transportName; - - - this.connectTimeout = TCP_CONNECT_TIMEOUT.get(settings); this.blockingClient = TCP_BLOCKING_CLIENT.get(settings); defaultConnectionProfile = buildDefaultConnectionProfile(settings); } @@ -204,6 +200,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i int connectionsPerNodeState = CONNECTIONS_PER_NODE_STATE.get(settings); int connectionsPerNodePing = CONNECTIONS_PER_NODE_PING.get(settings); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.setConnectTimeout(TCP_CONNECT_TIMEOUT.get(settings)); builder.addConnections(connectionsPerNodeBulk, TransportRequestOptions.Type.BULK); builder.addConnections(connectionsPerNodePing, TransportRequestOptions.Type.PING); // if we are not master eligible we don't need a dedicated channel to publish the state diff --git a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java index c63cc135a6f..1785853d0e1 100644 --- a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java +++ b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -27,6 +28,11 @@ public class ConnectionProfileTests extends ESTestCase { public void testBuildConnectionProfile() { ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + TimeValue connectTimeout = TimeValue.timeValueMillis(randomIntBetween(1, 10)); + final boolean setConnectTimeout = randomBoolean(); + if (setConnectTimeout) { + builder.setConnectTimeout(connectTimeout); + } builder.addConnections(1, TransportRequestOptions.Type.BULK); builder.addConnections(2, TransportRequestOptions.Type.STATE, TransportRequestOptions.Type.RECOVERY); builder.addConnections(3, TransportRequestOptions.Type.PING); @@ -39,6 +45,11 @@ public class ConnectionProfileTests extends ESTestCase { builder.addConnections(4, TransportRequestOptions.Type.REG); ConnectionProfile build = builder.build(); assertEquals(10, build.getNumConnections()); + if (setConnectTimeout) { + assertEquals(connectTimeout, build.getConnectTimeout()); + } else { + assertNull(build.getConnectTimeout()); + } Integer[] array = new Integer[10]; for (int i = 0; i < array.length; i++) { array[i] = i; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index c3bbb2d4e1c..20dd2d1a9fc 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -55,6 +55,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.FutureUtils; @@ -204,7 +205,7 @@ public class Netty4Transport extends TcpTransport { bootstrap.handler(getClientChannelInitializer()); - bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(connectTimeout.millis())); + bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(defaultConnectionProfile.getConnectTimeout().millis())); bootstrap.option(ChannelOption.TCP_NODELAY, TCP_NO_DELAY.get(settings)); bootstrap.option(ChannelOption.SO_KEEPALIVE, TCP_KEEP_ALIVE.get(settings)); @@ -270,7 +271,8 @@ public class Netty4Transport extends TcpTransport { logger.debug("using profile[{}], worker_count[{}], port[{}], bind_host[{}], publish_host[{}], compress[{}], " + "connect_timeout[{}], connections_per_node[{}/{}/{}/{}/{}], receive_predictor[{}->{}]", name, workerCount, settings.get("port"), settings.get("bind_host"), settings.get("publish_host"), compress, - connectTimeout, defaultConnectionProfile.getNumConnectionsPerType(TransportRequestOptions.Type.RECOVERY), + defaultConnectionProfile.getConnectTimeout(), + defaultConnectionProfile.getNumConnectionsPerType(TransportRequestOptions.Type.RECOVERY), defaultConnectionProfile.getNumConnectionsPerType(TransportRequestOptions.Type.BULK), defaultConnectionProfile.getNumConnectionsPerType(TransportRequestOptions.Type.REG), defaultConnectionProfile.getNumConnectionsPerType(TransportRequestOptions.Type.STATE), @@ -343,7 +345,18 @@ public class Netty4Transport extends TcpTransport { final NodeChannels nodeChannels = new NodeChannels(channels, profile); boolean success = false; try { - int numConnections = channels.length; + final int numConnections = channels.length; + final TimeValue connectTimeout; + final Bootstrap bootstrap; + final TimeValue defaultConnectTimeout = defaultConnectionProfile.getConnectTimeout(); + if (profile.getConnectTimeout() != null && profile.getConnectTimeout().equals(defaultConnectTimeout) == false) { + bootstrap = this.bootstrap.clone(this.bootstrap.config().group()); + bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(profile.getConnectTimeout().millis())); + connectTimeout = profile.getConnectTimeout(); + } else { + connectTimeout = defaultConnectTimeout; + bootstrap = this.bootstrap; + } final ArrayList connections = new ArrayList<>(numConnections); final InetSocketAddress address = node.getAddress().address(); for (int i = 0; i < numConnections; i++) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index cc32fb3f6ac..434990afba9 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.apache.lucene.util.Constants; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListenerResponseHandler; @@ -44,7 +45,12 @@ import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.sql.Time; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1721,4 +1727,46 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { serviceA.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), (request, message) -> {throw new AssertionError("boom");}); } + + public void testTimeoutPerConnection() throws IOException { + assumeTrue("Works only on BSD network stacks and apparently windows", + Constants.MAC_OS_X || Constants.FREE_BSD || Constants.WINDOWS); + try (ServerSocket socket = new ServerSocket()) { + // note - this test uses backlog=1 which is implementation specific ie. it might not work on some TCP/IP stacks + // on linux (at least newer ones) the listen(addr, backlog=1) should just ignore new connections if the queue is full which + // means that once we received an ACK from the client we just drop the packet on the floor (which is what we want) and we run + // into a connection timeout quickly. Yet other implementations can for instance can terminate the connection within the 3 way + // handshake which I haven't tested yet. + socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1); + socket.setReuseAddress(true); + DiscoveryNode first = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(), + socket.getLocalPort()), emptyMap(), + emptySet(), version0); + DiscoveryNode second = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(), + socket.getLocalPort()), emptyMap(), + emptySet(), version0); + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(1, + TransportRequestOptions.Type.BULK, + TransportRequestOptions.Type.PING, + TransportRequestOptions.Type.RECOVERY, + TransportRequestOptions.Type.REG, + TransportRequestOptions.Type.STATE); + + // connection with one connection and a large timeout -- should consume the one spot in the backlog queue + serviceA.connectToNode(first, builder.build()); + builder.setConnectTimeout(TimeValue.timeValueMillis(1)); + final ConnectionProfile profile = builder.build(); + // now with the 1ms timeout we got and test that is it's applied + long startTime = System.nanoTime(); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> { + serviceA.connectToNode(second, profile); + }); + final long now = System.nanoTime(); + final long timeTaken = TimeValue.nsecToMSec(now - startTime); + assertTrue("test didn't timeout quick enough, time taken: [" + timeTaken + "]", + timeTaken < TimeValue.timeValueSeconds(5).millis()); + assertEquals(ex.getMessage(), "[][" + second.getAddress() + "] connect_timeout[1ms]"); + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index fc33ce3c635..e9a97e030b2 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.common.util.concurrent.AbstractRunnable; @@ -46,6 +47,7 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; +import java.net.SocketTimeoutException; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -178,7 +180,13 @@ public class MockTcpTransport extends TcpTransport final InetSocketAddress address = node.getAddress().address(); // we just use a single connections configureSocket(socket); - socket.connect(address, (int) TCP_CONNECT_TIMEOUT.get(settings).millis()); + final TimeValue connectTimeout = profile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() + : profile.getConnectTimeout(); + try { + socket.connect(address, Math.toIntExact(connectTimeout.millis())); + } catch (SocketTimeoutException ex) { + throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex); + } MockChannel channel = new MockChannel(socket, address, "none", onClose); channel.loopRead(executor); mockChannels[0] = channel;