From 186c16ea41406b284bce896ab23771a93e93e7ec Mon Sep 17 00:00:00 2001 From: Simon Willnauer Date: Tue, 13 Jun 2017 09:37:05 +0200 Subject: [PATCH] Ensure pending transport handlers are invoked for all channel failures (#25150) Today, if a channel gets closed due to a disconnect we notify the response handler that the connection is closed and the node is disconnected. Unfortunately this is not a complete solution since it only works for published connections. Connections that are unpublished, i.e. for discovery, can indefinitely hang since we never invoke their handlers when we get a failure while a user is waiting for the response. This change adds connection tracking to TcpTransport that ensures we are notifying the corresponding connection if there is a failure on a channel. --- .../elasticsearch/transport/TcpTransport.java | 132 ++++++++++------ .../transport/TCPTransportTests.java | 7 +- .../transport/netty4/Netty4Transport.java | 27 +--- .../AbstractSimpleTransportTestCase.java | 143 ++++++++++++++---- .../transport/MockTcpTransport.java | 22 +-- 5 files changed, 217 insertions(+), 114 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 209127b1cdd..22aced389f8 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -66,6 +66,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.KeyedLock; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.util.iterable.Iterables; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.rest.RestStatus; @@ -85,7 +86,6 @@ import java.util.Collections; import 
java.util.EnumMap; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -112,6 +112,7 @@ import static org.elasticsearch.common.settings.Setting.timeSetting; import static org.elasticsearch.common.transport.NetworkExceptionHelper.isCloseConnectionException; import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet; public abstract class TcpTransport extends AbstractLifecycleComponent implements Transport { @@ -159,6 +160,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i protected volatile TransportServiceAdapter transportServiceAdapter; // node id to actual channel protected final ConcurrentMap connectedNodes = newConcurrentMap(); + private final Set openConnections = newConcurrentSet(); + protected final Map> serverChannels = newConcurrentMap(); protected final ConcurrentMap profileBoundAddresses = newConcurrentMap(); @@ -357,9 +360,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i private final DiscoveryNode node; private final AtomicBoolean closed = new AtomicBoolean(false); private final Version version; - private final Consumer onClose; - public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile, Consumer onClose) { + public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile) { this.node = node; this.channels = channels; assert channels.length == connectionProfile.getNumConnections() : "expected channels size to be == " @@ -370,7 +372,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i typeMapping.put(type, handle); } version = node.getVersion(); - this.onClose = onClose; } 
NodeChannels(NodeChannels channels, Version handshakeVersion) { @@ -378,7 +379,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i this.channels = channels.channels; this.typeMapping = channels.typeMapping; this.version = handshakeVersion; - this.onClose = channels.onClose; } @Override @@ -413,7 +413,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i try { closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList())); } finally { - onClose.accept(this); + onNodeChannelsClosed(this); } } } @@ -455,27 +455,28 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i if (nodeChannels != null) { return; } + boolean success = false; try { - try { - nodeChannels = openConnection(node, connectionProfile); - connectionValidator.accept(nodeChannels, connectionProfile); - } catch (Exception e) { - logger.trace( - (Supplier) () -> new ParameterizedMessage( - "failed to connect to [{}], cleaning dangling connections", node), e); - IOUtils.closeWhileHandlingException(nodeChannels); - throw e; - } + nodeChannels = openConnection(node, connectionProfile); + connectionValidator.accept(nodeChannels, connectionProfile); // we acquire a connection lock, so no way there is an existing connection connectedNodes.put(node, nodeChannels); if (logger.isDebugEnabled()) { logger.debug("connected to node [{}]", node); } transportServiceAdapter.onNodeConnected(node); + success = true; } catch (ConnectTransportException e) { throw e; } catch (Exception e) { throw new ConnectTransportException(node, "general node connection failure", e); + } finally { + if (success == false) { // close the connection if there is a failure + logger.trace( + (Supplier) () -> new ParameterizedMessage( + "failed to connect to [{}], cleaning dangling connections", node)); + IOUtils.closeWhileHandlingException(nodeChannels); + } } } } finally { @@ -518,7 +519,20 @@ public abstract class TcpTransport extends 
AbstractLifecycleComponent i try { ensureOpen(); try { - nodeChannels = connectToChannels(node, connectionProfile); + AtomicBoolean runOnce = new AtomicBoolean(false); + Consumer onClose = c -> { + assert isOpen(c) == false : "channel is still open when onClose is called"; + try { + onChannelClosed(c); + } finally { + // we only need to disconnect from the nodes once since all other channels + // will also try to run this we protect it from running multiple times. + if (runOnce.compareAndSet(false, true)) { + disconnectFromNodeChannel(c, "channel closed"); + } + } + }; + nodeChannels = connectToChannels(node, connectionProfile, onClose); final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() : @@ -526,8 +540,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ? connectTimeout : connectionProfile.getHandshakeTimeout(); final Version version = executeHandshake(node, channel, handshakeTimeout); - transportServiceAdapter.onConnectionOpened(nodeChannels); nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version + transportServiceAdapter.onConnectionOpened(nodeChannels); + openConnections.add(nodeChannels); success = true; return nodeChannels; } catch (ConnectTransportException e) { @@ -580,24 +595,37 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i /** * Disconnects from a node if a channel is found as part of that nodes channels. 
*/ - protected final void disconnectFromNodeChannel(final Channel channel, final Exception failure) { + protected final void disconnectFromNodeChannel(final Channel channel, final String reason) { threadPool.generic().execute(() -> { try { - try { + if (isOpen(channel)) { closeChannels(Collections.singletonList(channel)); - } finally { - for (DiscoveryNode node : connectedNodes.keySet()) { - if (disconnectFromNode(node, channel, ExceptionsHelper.detailedMessage(failure))) { - // if we managed to find this channel and disconnect from it, then break, no need to check on - // the rest of the nodes - break; - } - } } } catch (IOException e) { logger.warn("failed to close channel", e); } finally { - onChannelClosed(channel); + outer: + { + for (Map.Entry entry : connectedNodes.entrySet()) { + if (disconnectFromNode(entry.getKey(), channel, reason)) { + // if we managed to find this channel and disconnect from it, then break, no need to check on + // the rest of the nodes + // #onNodeChannelsClosed will remove it.. + assert openConnections.contains(entry.getValue()) == false : "NodeChannel#close should remove the connetion"; + // we can only be connected and published to a single node with one connection. 
So if disconnectFromNode + // returns true we can safely break out from here since we cleaned up everything needed + break outer; + } + } + // now if we haven't found the right connection in the connected nodes we have to go through the open connections + // it might be that the channel belongs to a connection that is not published + for (NodeChannels channels : openConnections) { + if (channels.hasChannel(channel)) { + IOUtils.closeWhileHandlingException(channels); + break; + } + } + } } }); } @@ -901,12 +929,11 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i "Error closing serverChannel for profile [{}]", entry.getKey()), e); } } - - for (Iterator it = connectedNodes.values().iterator(); it.hasNext();) { - NodeChannels nodeChannels = it.next(); - it.remove(); - IOUtils.closeWhileHandlingException(nodeChannels); - } + // we are holding a write lock so nobody modifies the connectedNodes / openConnections map - it's safe to first close + // all instances and then clear them maps + IOUtils.closeWhileHandlingException(Iterables.concat(connectedNodes.values(), openConnections)); + openConnections.clear(); + connectedNodes.clear(); stopInternal(); } finally { globalLock.writeLock().unlock(); @@ -923,11 +950,13 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } protected void onException(Channel channel, Exception e) { + String reason = ExceptionsHelper.detailedMessage(e); if (!lifecycle.started()) { // just close and ignore - we are already stopped and just need to make sure we release all resources - disconnectFromNodeChannel(channel, e); + disconnectFromNodeChannel(channel, reason); return; } + if (isCloseConnectionException(e)) { logger.trace( (Supplier) () -> new ParameterizedMessage( @@ -935,15 +964,15 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i channel), e); // close the channel, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, e); + 
disconnectFromNodeChannel(channel, reason); } else if (isConnectException(e)) { logger.trace((Supplier) () -> new ParameterizedMessage("connect exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, e); + disconnectFromNodeChannel(channel, reason); } else if (e instanceof BindException) { logger.trace((Supplier) () -> new ParameterizedMessage("bind exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, e); + disconnectFromNodeChannel(channel, reason); } else if (e instanceof CancelledKeyException) { logger.trace( (Supplier) () -> new ParameterizedMessage( @@ -951,7 +980,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, e); + disconnectFromNodeChannel(channel, reason); } else if (e instanceof TcpTransport.HttpOnTransportException) { // in case we are able to return data, serialize the exception content and sent it back to the client if (isOpen(channel)) { @@ -981,7 +1010,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i logger.warn( (Supplier) () -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e); // close the channel, which will cause a node to be disconnected if relevant - disconnectFromNodeChannel(channel, e); + disconnectFromNodeChannel(channel, reason); } } @@ -1012,7 +1041,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i */ protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener listener); - protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile 
connectionProfile) throws IOException; + protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile, + Consumer onChannelClose) throws IOException; /** * Called to tear down internal resources @@ -1607,7 +1637,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i /** * Called once the channel is closed for instance due to a disconnect or a closed socket etc. */ - protected final void onChannelClosed(Channel channel) { + private void onChannelClosed(Channel channel) { final Optional first = pendingHandshakes.entrySet().stream() .filter((entry) -> entry.getValue().channel == channel).map((e) -> e.getKey()).findFirst(); if (first.isPresent()) { @@ -1655,4 +1685,20 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i Releasables.close(optionalReleasable, transportAdaptorCallback::run); } } + + private void onNodeChannelsClosed(NodeChannels channels) { + // don't assert here since the channel / connection might not have been registered yet + final boolean remove = openConnections.remove(channels); + if (remove) { + transportServiceAdapter.onConnectionClosed(channels); + } + } + + final int getNumOpenConnections() { + return openConnections.size(); + } + + final int getNumConnectedNodes() { + return connectedNodes.size(); + } } diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index d8e35bd6f1a..a68416cc25a 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -224,8 +224,9 @@ public class TCPTransportTests extends ESTestCase { } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException { - return new NodeChannels(node, new Object[profile.getNumConnections()], profile, c -> {}); + protected NodeChannels 
connectToChannels(DiscoveryNode node, ConnectionProfile profile, + Consumer onChannelClose) throws IOException { + return new NodeChannels(node, new Object[profile.getNumConnections()], profile); } @Override @@ -241,7 +242,7 @@ public class TCPTransportTests extends ESTestCase { @Override public NodeChannels getConnection(DiscoveryNode node) { return new NodeChannels(node, new Object[MockTcpTransport.LIGHT_PROFILE.getNumConnections()], - MockTcpTransport.LIGHT_PROFILE, c -> {}); + MockTcpTransport.LIGHT_PROFILE); } }; DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 0f7b0416d68..abe0739c243 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -58,7 +58,6 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; -import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; @@ -74,7 +73,6 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -314,9 +312,9 @@ public class Netty4Transport extends TcpTransport { } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) { + 
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) { final Channel[] channels = new Channel[profile.getNumConnections()]; - final NodeChannels nodeChannels = new NodeChannels(node, channels, profile, transportServiceAdapter::onConnectionClosed); + final NodeChannels nodeChannels = new NodeChannels(node, channels, profile); boolean success = false; try { final TimeValue connectTimeout; @@ -336,6 +334,7 @@ public class Netty4Transport extends TcpTransport { connections.add(bootstrap.connect(address)); } final Iterator iterator = connections.iterator(); + final ChannelFutureListener closeListener = future -> onChannelClose.accept(future.channel()); try { for (int i = 0; i < channels.length; i++) { assert iterator.hasNext(); @@ -345,7 +344,7 @@ public class Netty4Transport extends TcpTransport { throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause()); } channels[i] = future.channel(); - channels[i].closeFuture().addListener(new ChannelCloseListener(node)); + channels[i].closeFuture().addListener(closeListener); } assert iterator.hasNext() == false : "not all created connection have been consumed"; } catch (final RuntimeException e) { @@ -374,24 +373,6 @@ public class Netty4Transport extends TcpTransport { return nodeChannels; } - private class ChannelCloseListener implements ChannelFutureListener { - - private final DiscoveryNode node; - - private ChannelCloseListener(DiscoveryNode node) { - this.node = node; - } - - @Override - public void operationComplete(final ChannelFuture future) throws Exception { - onChannelClosed(future.channel()); - NodeChannels nodeChannels = connectedNodes.get(node); - if (nodeChannels != null && nodeChannels.hasChannel(future.channel())) { - threadPool.generic().execute(() -> disconnectFromNode(node, future.channel(), "channel closed event")); - } - } - } - @Override protected void sendMessage(Channel channel, BytesReference 
reference, ActionListener listener) { final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference)); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 48d90e3ec63..99704235cc7 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -41,6 +41,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.mocksocket.MockServerSocket; import org.elasticsearch.node.Node; @@ -60,6 +61,7 @@ import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -167,6 +169,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { try { assertNoPendingHandshakes(serviceA.getOriginalTransport()); assertNoPendingHandshakes(serviceB.getOriginalTransport()); + assertPendingConnections(0, serviceA.getOriginalTransport()); + assertPendingConnections(0, serviceB.getOriginalTransport()); } finally { IOUtils.close(serviceA, serviceB, () -> { try { @@ -190,6 +194,13 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } } + public void assertPendingConnections(int numConnections, Transport transport) { + if (transport instanceof TcpTransport) { + TcpTransport tcpTransport = (TcpTransport) transport; + assertEquals(numConnections, 
tcpTransport.getNumOpenConnections() - tcpTransport.getNumConnectedNodes()); + } + } + public void testHelloWorld() { serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC, (request, channel) -> { @@ -748,11 +759,9 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { public void testNotifyOnShutdown() throws Exception { final CountDownLatch latch2 = new CountDownLatch(1); - - serviceA.registerRequestHandler("foobar", StringMessageRequest::new, ThreadPool.Names.GENERIC, - new TransportRequestHandler() { - @Override - public void messageReceived(StringMessageRequest request, TransportChannel channel) { + try { + serviceA.registerRequestHandler("foobar", StringMessageRequest::new, ThreadPool.Names.GENERIC, + (request, channel) -> { try { latch2.await(); logger.info("Stop ServiceB now"); @@ -760,16 +769,19 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } catch (Exception e) { fail(e.getMessage()); } - } - }); - TransportFuture foobar = serviceB.submitRequest(nodeA, "foobar", - new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME); - latch2.countDown(); - try { - foobar.txGet(); - fail("TransportException expected"); - } catch (TransportException ex) { + }); + TransportFuture foobar = serviceB.submitRequest(nodeA, "foobar", + new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME); + latch2.countDown(); + try { + foobar.txGet(); + fail("TransportException expected"); + } catch (TransportException ex) { + } + } finally { + serviceB.close(); // make sure we are fully closed here otherwise we might run into assertions down the road + serviceA.disconnectFromNode(nodeB); } } @@ -1469,12 +1481,9 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { public void testMockUnresponsiveRule() throws IOException { serviceA.registerRequestHandler("sayHello", 
StringMessageRequest::new, ThreadPool.Names.GENERIC, - new TransportRequestHandler() { - @Override - public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception { - assertThat("moshe", equalTo(request.message)); - throw new RuntimeException("bad message !!!"); - } + (request, channel) -> { + assertThat("moshe", equalTo(request.message)); + throw new RuntimeException("bad message !!!"); }); serviceB.addUnresponsiveRule(serviceA); @@ -1852,7 +1861,11 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } logger.debug("DONE"); serviceC.close(); - + // when we close C here we have to disconnect the service otherwise assertions mit trip with pending connections in tearDown + // since the disconnect will then happen concurrently and that might confuse the assertions since we disconnect due to a + // connection reset by peer or other exceptions depending on the implementation + serviceB.disconnectFromNode(nodeC); + serviceA.disconnectFromNode(nodeC); } public void testRegisterHandlerTwice() { @@ -2137,7 +2150,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { @Override public void handleException(TransportException exp) { try { - assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException); + if (exp instanceof SendRequestTransportException) { + assertTrue(exp.getCause().getClass().toString(), exp.getCause() instanceof NodeNotConnectedException); + } else { + // here the concurrent disconnect was faster and invoked the listener first + assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException); + } } finally { latch.countDown(); } @@ -2155,12 +2173,83 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { TransportRequestOptions.Type.RECOVERY, TransportRequestOptions.Type.REG, TransportRequestOptions.Type.STATE); - Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build()); - 
serviceB.sendRequest(connection, "action", new TestRequest(randomFrom("fail", "pass")), TransportRequestOptions.EMPTY, - transportResponseHandler); - connection.close(); + try (Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build())) { + serviceC.close(); + serviceB.sendRequest(connection, "action", new TestRequest("boom"), TransportRequestOptions.EMPTY, + transportResponseHandler); + } latch.await(); - serviceC.close(); + } + + public void testConcurrentDisconnectOnNonPublishedConnection() throws IOException, InterruptedException { + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + CountDownLatch receivedLatch = new CountDownLatch(1); + CountDownLatch sendResponseLatch = new CountDownLatch(1); + serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + // don't block on a network thread here + threadPool.generic().execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + try { + channel.sendResponse(e); + } catch (IOException e1) { + throw new UncheckedIOException(e1); + } + } + + @Override + protected void doRun() throws Exception { + receivedLatch.countDown(); + sendResponseLatch.await(); + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } + }); + }); + serviceC.start(); + serviceC.acceptIncomingRequests(); + CountDownLatch responseLatch = new CountDownLatch(1); + TransportResponseHandler transportResponseHandler = new TransportResponseHandler() { + @Override + public TransportResponse newInstance() { + return TransportResponse.Empty.INSTANCE; + } + + @Override + public void handleResponse(TransportResponse response) { + responseLatch.countDown(); + } + + @Override + public void handleException(TransportException exp) { + responseLatch.countDown(); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }; + + ConnectionProfile.Builder 
builder = new ConnectionProfile.Builder(); + builder.addConnections(1, + TransportRequestOptions.Type.BULK, + TransportRequestOptions.Type.PING, + TransportRequestOptions.Type.RECOVERY, + TransportRequestOptions.Type.REG, + TransportRequestOptions.Type.STATE); + + try (Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build())) { + serviceB.sendRequest(connection, "action", new TestRequest("hello world"), TransportRequestOptions.EMPTY, + transportResponseHandler); + receivedLatch.await(); + assertPendingConnections(1, serviceB.getOriginalTransport()); + serviceC.close(); + assertPendingConnections(0, serviceC.getOriginalTransport()); + sendResponseLatch.countDown(); + responseLatch.await(); + } + assertPendingConnections(0, serviceC.getOriginalTransport()); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 765d675f2da..38a1701a7e1 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -60,7 +60,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; @@ -178,25 +177,13 @@ public class MockTcpTransport extends TcpTransport } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException { + protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, + Consumer onChannelClose) throws IOException { final MockChannel[] mockChannels = new MockChannel[1]; - final NodeChannels nodeChannels = new 
NodeChannels(node, mockChannels, LIGHT_PROFILE, - transportServiceAdapter::onConnectionClosed); // we always use light here + final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here boolean success = false; final MockSocket socket = new MockSocket(); try { - Consumer onClose = (channel) -> { - final NodeChannels connected = connectedNodes.get(node); - if (connected != null && connected.hasChannel(channel)) { - try { - executor.execute(() -> { - disconnectFromNode(node, channel, "channel closed event"); - }); - } catch (RejectedExecutionException ex) { - logger.debug("failed to run disconnectFromNode - node is shutting down"); - } - } - }; final InetSocketAddress address = node.getAddress().address(); // we just use a single connections configureSocket(socket); @@ -206,7 +193,7 @@ public class MockTcpTransport extends TcpTransport } catch (SocketTimeoutException ex) { throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex); } - MockChannel channel = new MockChannel(socket, address, "none", onClose); + MockChannel channel = new MockChannel(socket, address, "none", onChannelClose); channel.loopRead(executor); mockChannels[0] = channel; success = true; @@ -376,7 +363,6 @@ public class MockTcpTransport extends TcpTransport synchronized (openChannels) { removedChannel = openChannels.remove(this); } - onChannelClosed(this); IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels), () -> cancellableThreads.cancel("channel closed"), onClose); assert removedChannel: "Channel was not removed or removed twice?";