diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 31d871a2ae8..5db64c98b85 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -87,6 +87,7 @@ import java.util.Collections; import java.util.EnumMap; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -113,7 +114,6 @@ import static org.elasticsearch.common.settings.Setting.timeSetting; import static org.elasticsearch.common.transport.NetworkExceptionHelper.isCloseConnectionException; import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; -import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet; public abstract class TcpTransport extends AbstractLifecycleComponent implements Transport { @@ -161,7 +161,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i protected volatile TransportServiceAdapter transportServiceAdapter; // node id to actual channel protected final ConcurrentMap connectedNodes = newConcurrentMap(); - private final Set openConnections = newConcurrentSet(); protected final Map> serverChannels = newConcurrentMap(); protected final ConcurrentMap profileBoundAddresses = newConcurrentMap(); @@ -171,7 +170,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i // this lock is here to make sure we close this transport and disconnect all the client nodes // connections while no connect operations is going on... (this might help with 100% CPU when stopping the transport?) 
- protected final ReadWriteLock globalLock = new ReentrantReadWriteLock(); + protected final ReadWriteLock closeLock = new ReentrantReadWriteLock(); protected final boolean compress; protected volatile BoundTransportAddress boundAddress; private final String transportName; @@ -390,15 +389,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i return version; } - public boolean hasChannel(Channel channel) { - for (Channel channel1 : channels) { - if (channel.equals(channel1)) { - return true; - } - } - return false; - } - public List getChannels() { return Arrays.asList(channels); } @@ -412,12 +402,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } @Override - public synchronized void close() throws IOException { + public void close() throws IOException { if (closed.compareAndSet(false, true)) { try { closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList())); } finally { - onNodeChannelsClosed(this); + transportServiceAdapter.onConnectionClosed(this); } } } @@ -436,6 +426,10 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i Channel channel = channel(options.type()); sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), (byte) 0); } + + boolean isClosed() { + return closed.get(); + } } @Override @@ -451,7 +445,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i if (node == null) { throw new ConnectTransportException(null, "can't connect to a null node"); } - globalLock.readLock().lock(); // ensure we don't open connections while we are closing + closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); try (Releasable ignored = connectionLock.acquire(node.getId())) { @@ -468,7 +462,24 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i if (logger.isDebugEnabled()) { logger.debug("connected to node [{}]", node); } - 
transportServiceAdapter.onNodeConnected(node); + try { + transportServiceAdapter.onNodeConnected(node); + } finally { + if (nodeChannels.isClosed()) { + // we got closed concurrently due to a disconnect or some other event on the channel. + // the close callback will close the NodeChannel instance first and then try to remove + // the connection from the connected nodes. It will NOT acquire the connectionLock for + // the node to prevent any blocking calls on network threads. Yet, we still establish a happens + // before relationship to the connectedNodes.put since we check if we can remove the + // (DiscoveryNode, NodeChannels) tuple from the map after we closed. Here we check if it's closed and if so we + // try to remove it first; either way one of the two wins even if the callback has run before we even added the + // tuple to the map since in that case we remove it here again + if (connectedNodes.remove(node, nodeChannels)) { + transportServiceAdapter.onNodeDisconnected(node); + } + throw new NodeNotConnectedException(node, "connection concurrently closed"); + } + } success = true; } catch (ConnectTransportException e) { throw e; @@ -484,7 +495,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } } } finally { - globalLock.readLock().unlock(); + closeLock.readLock().unlock(); } } @@ -519,11 +530,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i boolean success = false; NodeChannels nodeChannels = null; connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); - globalLock.readLock().lock(); // ensure we don't open connections while we are closing + closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); try { - AtomicBoolean runOnce = new AtomicBoolean(false); + final AtomicBoolean runOnce = new AtomicBoolean(false); + final AtomicReference connectionRef = new AtomicReference<>(); Consumer onClose = c -> { assert isOpen(c) == 
false : "channel is still open when onClose is called"; try { @@ -532,7 +544,10 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i // we only need to disconnect from the nodes once since all other channels // will also try to run this we protect it from running multiple times. if (runOnce.compareAndSet(false, true)) { - disconnectFromNodeChannel(c, "channel closed"); + NodeChannels connection = connectionRef.get(); + if (connection != null) { + disconnectFromNodeCloseAndNotify(node, connection); + } } } }; @@ -546,7 +561,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i final Version version = executeHandshake(node, channel, handshakeTimeout); nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version transportServiceAdapter.onConnectionOpened(nodeChannels); - openConnections.add(nodeChannels); + connectionRef.set(nodeChannels); success = true; return nodeChannels; } catch (ConnectTransportException e) { @@ -561,77 +576,38 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } } } finally { - globalLock.readLock().unlock(); + closeLock.readLock().unlock(); } } - /** - * Disconnects from a node, only if the relevant channel is found to be part of the node channels. 
- */ - protected boolean disconnectFromNode(DiscoveryNode node, Channel channel, String reason) { - // this might be called multiple times from all the node channels, so do a lightweight - // check outside of the lock - NodeChannels nodeChannels = connectedNodes.get(node); - if (nodeChannels != null && nodeChannels.hasChannel(channel)) { - try (Releasable ignored = connectionLock.acquire(node.getId())) { - nodeChannels = connectedNodes.get(node); - // check again within the connection lock, if its still applicable to remove it - if (nodeChannels != null && nodeChannels.hasChannel(channel)) { - connectedNodes.remove(node); - closeAndNotify(node, nodeChannels, reason); - return true; - } - } - } - return false; - } - - private void closeAndNotify(DiscoveryNode node, NodeChannels nodeChannels, String reason) { + private void disconnectFromNodeCloseAndNotify(DiscoveryNode node, NodeChannels nodeChannels) { + assert nodeChannels != null : "nodeChannels must not be null"; try { - logger.debug("disconnecting from [{}], {}", node, reason); IOUtils.closeWhileHandlingException(nodeChannels); } finally { - logger.trace("disconnected from [{}], {}", node, reason); - transportServiceAdapter.onNodeDisconnected(node); + if (closeLock.readLock().tryLock()) { + try { + if (connectedNodes.remove(node, nodeChannels)) { + transportServiceAdapter.onNodeDisconnected(node); + } + } finally { + closeLock.readLock().unlock(); + } + } } } /** * Disconnects from a node if a channel is found as part of that nodes channels. 
*/ - protected final void disconnectFromNodeChannel(final Channel channel, final String reason) { - threadPool.generic().execute(() -> { + protected final void closeChannelWhileHandlingExceptions(final Channel channel) { + if (isOpen(channel)) { try { - if (isOpen(channel)) { - closeChannels(Collections.singletonList(channel)); - } + closeChannels(Collections.singletonList(channel)); } catch (IOException e) { logger.warn("failed to close channel", e); - } finally { - outer: - { - for (Map.Entry entry : connectedNodes.entrySet()) { - if (disconnectFromNode(entry.getKey(), channel, reason)) { - // if we managed to find this channel and disconnect from it, then break, no need to check on - // the rest of the nodes - // #onNodeChannelsClosed will remove it.. - assert openConnections.contains(entry.getValue()) == false : "NodeChannel#close should remove the connetion"; - // we can only be connected and published to a single node with one connection. So if disconnectFromNode - // returns true we can safely break out from here since we cleaned up everything needed - break outer; - } - } - // now if we haven't found the right connection in the connected nodes we have to go through the open connections - // it might be that the channel belongs to a connection that is not published - for (NodeChannels channels : openConnections) { - if (channels.hasChannel(channel)) { - IOUtils.closeWhileHandlingException(channels); - break; - } - } - } } - }); + } } @Override @@ -645,10 +621,14 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i @Override public void disconnectFromNode(DiscoveryNode node) { + closeLock.readLock().lock(); + NodeChannels nodeChannels = null; try (Releasable ignored = connectionLock.acquire(node.getId())) { - NodeChannels nodeChannels = connectedNodes.remove(node); - if (nodeChannels != null) { - closeAndNotify(node, nodeChannels, "due to explicit disconnect call"); + nodeChannels = connectedNodes.remove(node); + } finally { + 
closeLock.readLock().unlock(); + if (nodeChannels != null) { // if we found it and removed it we close and notify + IOUtils.closeWhileHandlingException(nodeChannels, () -> transportServiceAdapter.onNodeDisconnected(node)); } } } @@ -921,7 +901,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i final CountDownLatch latch = new CountDownLatch(1); // make sure we run it on another thread than a possible IO handler thread threadPool.generic().execute(() -> { - globalLock.writeLock().lock(); + closeLock.writeLock().lock(); try { // first stop to accept any incoming connections so nobody can connect to this transport for (Map.Entry> entry : serverChannels.entrySet()) { @@ -935,12 +915,19 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } // we are holding a write lock so nobody modifies the connectedNodes / openConnections map - it's safe to first close // all instances and then clear them maps - IOUtils.closeWhileHandlingException(Iterables.concat(connectedNodes.values(), openConnections)); - openConnections.clear(); - connectedNodes.clear(); + Iterator> iterator = connectedNodes.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry next = iterator.next(); + try { + IOUtils.closeWhileHandlingException(next.getValue()); + transportServiceAdapter.onNodeDisconnected(next.getKey()); + } finally { + iterator.remove(); + } + } stopInternal(); } finally { - globalLock.writeLock().unlock(); + closeLock.writeLock().unlock(); latch.countDown(); } }); @@ -954,10 +941,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } protected void onException(Channel channel, Exception e) { - String reason = ExceptionsHelper.detailedMessage(e); if (!lifecycle.started()) { // just close and ignore - we are already stopped and just need to make sure we release all resources - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); return; } @@ -968,15 +954,15 @@ public 
abstract class TcpTransport extends AbstractLifecycleComponent i channel), e); // close the channel, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); } else if (isConnectException(e)) { logger.trace((Supplier) () -> new ParameterizedMessage("connect exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); } else if (e instanceof BindException) { logger.trace((Supplier) () -> new ParameterizedMessage("bind exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); } else if (e instanceof CancelledKeyException) { logger.trace( (Supplier) () -> new ParameterizedMessage( @@ -984,7 +970,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); } else if (e instanceof TcpTransport.HttpOnTransportException) { // in case we are able to return data, serialize the exception content and sent it back to the client if (isOpen(channel)) { @@ -1015,7 +1001,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i logger.warn( (Supplier) () -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e); // close the channel, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, reason); + closeChannelWhileHandlingExceptions(channel); } } @@ -1712,22 +1698,6 @@ public abstract class TcpTransport extends 
AbstractLifecycleComponent i } } - private void onNodeChannelsClosed(NodeChannels channels) { - // don't assert here since the channel / connection might not have been registered yet - final boolean remove = openConnections.remove(channels); - if (remove) { - transportServiceAdapter.onConnectionClosed(channels); - } - } - - final int getNumOpenConnections() { - return openConnections.size(); - } - - final int getNumConnectedNodes() { - return connectedNodes.size(); - } - /** * Returns count of currently open connections */ diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index da730ee5645..206cfeeb62e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -169,8 +169,6 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { try { assertNoPendingHandshakes(serviceA.getOriginalTransport()); assertNoPendingHandshakes(serviceB.getOriginalTransport()); - assertPendingConnections(0, serviceA.getOriginalTransport()); - assertPendingConnections(0, serviceB.getOriginalTransport()); } finally { IOUtils.close(serviceA, serviceB, () -> { try { @@ -194,12 +192,6 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } } - public void assertPendingConnections(int numConnections, Transport transport) { - if (transport instanceof TcpTransport) { - TcpTransport tcpTransport = (TcpTransport) transport; - assertEquals(numConnections, tcpTransport.getNumOpenConnections() - tcpTransport.getNumConnectedNodes()); - } - } public void testHelloWorld() { serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC, @@ -2243,13 +2235,10 @@ public abstract class AbstractSimpleTransportTestCase extends 
ESTestCase { serviceB.sendRequest(connection, "action", new TestRequest("hello world"), TransportRequestOptions.EMPTY, transportResponseHandler); receivedLatch.await(); - assertPendingConnections(1, serviceB.getOriginalTransport()); serviceC.close(); - assertPendingConnections(0, serviceC.getOriginalTransport()); sendResponseLatch.countDown(); responseLatch.await(); } - assertPendingConnections(0, serviceC.getOriginalTransport()); } public void testTransportStats() throws Exception { @@ -2341,11 +2330,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(46, stats.getRxSize().getBytes()); assertEquals(91, stats.getTxSize().getBytes()); } finally { - try { - assertPendingConnections(0, serviceC.getOriginalTransport()); - } finally { - serviceC.close(); - } + serviceC.close(); } } @@ -2443,11 +2428,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals(185 + addressLen, stats.getRxSize().getBytes()); assertEquals(91, stats.getTxSize().getBytes()); } finally { - try { - assertPendingConnections(0, serviceC.getOriginalTransport()); - } finally { - serviceC.close(); - } + serviceC.close(); } } }