diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 30a18e05611..11afec1bf1e 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -101,6 +101,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -357,8 +358,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i private final DiscoveryNode node; private final AtomicBoolean closed = new AtomicBoolean(false); private final Version version; + private final Consumer onClose; - public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile) { + public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile, Consumer onClose) { this.node = node; this.channels = channels; assert channels.length == connectionProfile.getNumConnections() : "expected channels size to be == " @@ -369,6 +371,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i typeMapping.put(type, handle); } version = node.getVersion(); + this.onClose = onClose; } NodeChannels(NodeChannels channels, Version handshakeVersion) { @@ -376,6 +379,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i this.channels = channels.channels; this.typeMapping = channels.typeMapping; this.version = handshakeVersion; + this.onClose = channels.onClose; } @Override @@ -408,6 +412,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i public synchronized void close() throws IOException { if (closed.compareAndSet(false, true)) { closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList())); + onClose.accept(this); } } @@ -519,8 +524,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ? connectTimeout : connectionProfile.getHandshakeTimeout(); final Version version = executeHandshake(node, channel, handshakeTimeout); - transportServiceAdapter.onConnectionOpened(node); - nodeChannels = new NodeChannels(nodeChannels, version);// clone the channels - we now have the correct version + transportServiceAdapter.onConnectionOpened(nodeChannels); + nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version success = true; return nodeChannels; } catch (ConnectTransportException e) { diff --git a/core/src/main/java/org/elasticsearch/transport/Transport.java b/core/src/main/java/org/elasticsearch/transport/Transport.java index d3dcd8bb5c1..a32289332ea 100644 --- a/core/src/main/java/org/elasticsearch/transport/Transport.java +++ b/core/src/main/java/org/elasticsearch/transport/Transport.java @@ -132,5 +132,13 @@ public interface Transport extends LifecycleComponent { default Version getVersion() { return getNode().getVersion(); } + + /** + * Returns a key that this connection can be cached on. Delegating subclasses must delegate method call to + * the original connection. + */ + default Object getCacheKey() { + return this; + } } } diff --git a/core/src/main/java/org/elasticsearch/transport/TransportConnectionListener.java b/core/src/main/java/org/elasticsearch/transport/TransportConnectionListener.java index 3f277a0ee11..de767986b9f 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportConnectionListener.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportConnectionListener.java @@ -33,8 +33,14 @@ public interface TransportConnectionListener { */ default void onNodeDisconnected(DiscoveryNode node) {} + /** + * Called once a node connection is closed. The connection might not have been registered in the + * transport as a shared connection to a specific node + */ + default void onConnectionClosed(Transport.Connection connection) {} + /** * Called once a node connection is opened. */ - default void onConnectionOpened(DiscoveryNode node) {} + default void onConnectionOpened(Transport.Connection connection) {} } diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index 7de96063615..e5382e4e261 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -569,7 +569,7 @@ public class TransportService extends AbstractLifecycleComponent { } Supplier storedContextSupplier = threadPool.getThreadContext().newRestorableContext(true); TransportResponseHandler responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler); - clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection.getNode(), action, timeoutHandler)); + clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection, action, timeoutHandler)); if (lifecycle.stoppedOrClosed()) { // if we are not started the exception handling will remove the RequestHolder again and calls the handler to notify // the caller. It will only notify if the toStop code hasn't done the work yet. @@ -810,7 +810,7 @@ public class TransportService extends AbstractLifecycleComponent { } holder.cancelTimeout(); if (traceEnabled() && shouldTraceAction(holder.action())) { - traceReceivedResponse(requestId, holder.node(), holder.action()); + traceReceivedResponse(requestId, holder.connection().getNode(), holder.action()); } return holder.handler(); } @@ -855,12 +855,12 @@ public class TransportService extends AbstractLifecycleComponent { } @Override - public void onConnectionOpened(DiscoveryNode node) { + public void onConnectionOpened(Transport.Connection connection) { // capture listeners before spawning the background callback so the following pattern won't trigger a call // connectToNode(); connection is completed successfully // addConnectionListener(); this listener shouldn't be called final Stream listenersToNotify = TransportService.this.connectionListeners.stream(); - threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(node))); + threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(connection))); } @Override @@ -871,20 +871,28 @@ public class TransportService extends AbstractLifecycleComponent { connectionListener.onNodeDisconnected(node); } }); + } catch (EsRejectedExecutionException ex) { + logger.debug("Rejected execution on NodeDisconnected", ex); + } + } + + @Override + public void onConnectionClosed(Transport.Connection connection) { + try { for (Map.Entry entry : clientHandlers.entrySet()) { RequestHolder holder = entry.getValue(); - if (holder.node().equals(node)) { + if (holder.connection().getCacheKey().equals(connection.getCacheKey())) { final RequestHolder holderToNotify = clientHandlers.remove(entry.getKey()); if (holderToNotify != null) { // callback that an exception happened, but on a different thread since we don't // want handlers to worry about stack overflows - threadPool.generic().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException(node, - holderToNotify.action()))); + threadPool.generic().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException( + connection.getNode(), holderToNotify.action()))); } } } } catch (EsRejectedExecutionException ex) { - logger.debug("Rejected execution on NodeDisconnected", ex); + logger.debug("Rejected execution on onConnectionClosed", ex); } } @@ -929,13 +937,14 @@ public class TransportService extends AbstractLifecycleComponent { if (holder != null) { // add it to the timeout information holder, in case we are going to get a response later long timeoutTime = System.currentTimeMillis(); - timeoutInfoHandlers.put(requestId, new TimeoutInfoHolder(holder.node(), holder.action(), sentTime, timeoutTime)); + timeoutInfoHandlers.put(requestId, new TimeoutInfoHolder(holder.connection().getNode(), holder.action(), sentTime, + timeoutTime)); // now that we have the information visible via timeoutInfoHandlers, we try to remove the request id final RequestHolder removedHolder = clientHandlers.remove(requestId); if (removedHolder != null) { assert removedHolder == holder : "two different holder instances for request [" + requestId + "]"; removedHolder.handler().handleException( - new ReceiveTimeoutTransportException(holder.node(), holder.action(), + new ReceiveTimeoutTransportException(holder.connection().getNode(), holder.action(), "request_id [" + requestId + "] timed out after [" + (timeoutTime - sentTime) + "ms]")); } else { // response was processed, remove timeout info. @@ -990,15 +999,15 @@ public class TransportService extends AbstractLifecycleComponent { private final TransportResponseHandler handler; - private final DiscoveryNode node; + private final Transport.Connection connection; private final String action; private final TimeoutHandler timeoutHandler; - RequestHolder(TransportResponseHandler handler, DiscoveryNode node, String action, TimeoutHandler timeoutHandler) { + RequestHolder(TransportResponseHandler handler, Transport.Connection connection, String action, TimeoutHandler timeoutHandler) { this.handler = handler; - this.node = node; + this.connection = connection; this.action = action; this.timeoutHandler = timeoutHandler; } @@ -1007,8 +1016,8 @@ public class TransportService extends AbstractLifecycleComponent { return handler; } - public DiscoveryNode node() { - return this.node; + public Transport.Connection connection() { + return this.connection; } public String action() { diff --git a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java index 2ed44586f73..6aa47d27bbd 100644 --- a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java @@ -604,8 +604,8 @@ public class UnicastZenPingTests extends ESTestCase { // install a listener to check that no new connections are made handleA.transportService.addConnectionListener(new TransportConnectionListener() { @Override - public void onConnectionOpened(DiscoveryNode node) { - fail("should not open any connections. got [" + node + "]"); + public void onConnectionOpened(Transport.Connection connection) { + fail("should not open any connections. got [" + connection.getNode() + "]"); } }); diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index c14d6ec9e05..eb9e6496521 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -204,7 +204,7 @@ public class TCPTransportTests extends ESTestCase { @Override protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException { - return new NodeChannels(node, new Object[profile.getNumConnections()], profile); + return new NodeChannels(node, new Object[profile.getNumConnections()], profile, c -> {}); } @Override @@ -220,7 +220,7 @@ public class TCPTransportTests extends ESTestCase { @Override public NodeChannels getConnection(DiscoveryNode node) { return new NodeChannels(node, new Object[MockTcpTransport.LIGHT_PROFILE.getNumConnections()], - MockTcpTransport.LIGHT_PROFILE); + MockTcpTransport.LIGHT_PROFILE, c -> {}); } }; DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 8adeb665e04..a86bbfbe2b6 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -320,7 +320,7 @@ public class Netty4Transport extends TcpTransport { @Override protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) { final Channel[] channels = new Channel[profile.getNumConnections()]; - final NodeChannels nodeChannels = new NodeChannels(node, channels, profile); + final NodeChannels nodeChannels = new NodeChannels(node, channels, profile, transportServiceAdapter::onConnectionClosed); boolean success = false; try { final TimeValue connectTimeout; diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index a25f435af2c..210190940d2 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -777,6 +777,11 @@ public final class MockTransportService extends TransportService { public void close() throws IOException { connection.close(); } + + @Override + public Object getCacheKey() { + return connection.getCacheKey(); + } } public Transport getOriginalTransport() { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index afbacf6f63f..48d90e3ec63 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2099,9 +2099,6 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { @Override public String executor() { - if (1 == 1) - return "same"; - return randomFrom(executors); } }; @@ -2111,4 +2108,59 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { latch.await(); } + public void testHandlerIsInvokedOnConnectionClose() throws IOException, InterruptedException { + List executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet()); + CollectionUtil.timSort(executors); // makes sure it's reproducible + TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + // do nothing + }); + serviceC.start(); + serviceC.acceptIncomingRequests(); + CountDownLatch latch = new CountDownLatch(1); + TransportResponseHandler transportResponseHandler = new TransportResponseHandler() { + @Override + public TransportResponse newInstance() { + return TransportResponse.Empty.INSTANCE; + } + + @Override + public void handleResponse(TransportResponse response) { + try { + fail("no response expected"); + } finally { + latch.countDown(); + } + } + + @Override + public void handleException(TransportException exp) { + try { + assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException); + } finally { + latch.countDown(); + } + } + + @Override + public String executor() { + return randomFrom(executors); + } + }; + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(1, + TransportRequestOptions.Type.BULK, + TransportRequestOptions.Type.PING, + TransportRequestOptions.Type.RECOVERY, + TransportRequestOptions.Type.REG, + TransportRequestOptions.Type.STATE); + Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build()); + serviceB.sendRequest(connection, "action", new TestRequest(randomFrom("fail", "pass")), TransportRequestOptions.EMPTY, + transportResponseHandler); + connection.close(); + latch.await(); + serviceC.close(); + } + } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index f9e5ff8981e..765d675f2da 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -180,7 +180,8 @@ public class MockTcpTransport extends TcpTransport @Override protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException { final MockChannel[] mockChannels = new MockChannel[1]; - final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here + final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE, + transportServiceAdapter::onConnectionClosed); // we always use light here boolean success = false; final MockSocket socket = new MockSocket(); try {