diff --git a/core/src/main/java/org/elasticsearch/common/transport/NetworkExceptionHelper.java b/core/src/main/java/org/elasticsearch/common/transport/NetworkExceptionHelper.java index 77a39a8c22b..0317026b6be 100644 --- a/core/src/main/java/org/elasticsearch/common/transport/NetworkExceptionHelper.java +++ b/core/src/main/java/org/elasticsearch/common/transport/NetworkExceptionHelper.java @@ -55,6 +55,9 @@ public class NetworkExceptionHelper { if (e.getMessage().contains("Connection timed out")) { return true; } + if (e.getMessage().equals("Socket is closed")) { + return true; + } } return false; } diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index b0139902a42..0e601ecb5b5 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -918,19 +918,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i transportServiceAdapter.onRequestSent(node, requestId, action, request, finalOptions); } }; - try { - sendMessage(targetChannel, message, onRequestSent, false); - } catch (IOException ex) { - if (nodeConnected(node)) { - throw ex; - } else { - // we might got disconnected in between the nodeChannel(node, options) call and the sending - - // in that case throw a subclass of ConnectTransportException since some code retries based on this - // see TransportMasterNodeAction for instance - throw new NodeNotConnectedException(node, "Node not connected"); - } - } - addedReleaseListener = true; + addedReleaseListener = internalSendMessage(targetChannel, message, onRequestSent); } finally { IOUtils.close(stream); if (!addedReleaseListener) { @@ -939,6 +927,25 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } } + /** + * sends a message view the given channel, using the given callbacks. + * + * @return true if the message was successfully sent or false when an error occurred and the error hanlding logic was activated + * + */ + private boolean internalSendMessage(Channel targetChannel, BytesReference message, Runnable onRequestSent) throws IOException { + boolean success; + try { + sendMessage(targetChannel, message, onRequestSent, false); + success = true; + } catch (IOException ex) { + // passing exception handling to deal with this and raise disconnect events and decide the right logging level + onException(targetChannel, ex); + success = false; + } + return success; + } + /** * Sends back an error response to the caller via the given channel * @param nodeVersion the caller node version @@ -997,9 +1004,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i transportServiceAdapter.onResponseSent(requestId, action, response, finalOptions); } }; - sendMessage(channel, reference, onRequestSent, false); - addedReleaseListener = true; - + addedReleaseListener = internalSendMessage(channel, reference, onRequestSent); } finally { IOUtils.close(stream); if (!addedReleaseListener) { diff --git a/core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java b/core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java index eba5fd57734..c94e62ea422 100644 --- a/core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java @@ -38,6 +38,7 @@ import org.elasticsearch.common.transport.LocalTransportAddress; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; @@ -230,12 +231,30 @@ public class LocalTransport extends AbstractLifecycleComponent implements Transp final byte[] data = BytesReference.toBytes(stream.bytes()); transportServiceAdapter.sent(data.length); transportServiceAdapter.onRequestSent(node, requestId, action, request, options); - targetTransport.workers().execute(() -> { - ThreadContext threadContext = targetTransport.threadPool.getThreadContext(); + targetTransport.receiveMessage(version, data, action, requestId, this); + } + } + + /** + * entry point for incoming messages + * + * @param version the version used to serialize the message + * @param data message data + * @param action the action associated with this message (only used for error handling when data is not parsable) + * @param requestId requestId if the message is request (only used for error handling when data is not parsable) + * @param sourceTransport the source transport to respond to. + */ + public void receiveMessage(Version version, byte[] data, String action, @Nullable Long requestId, LocalTransport sourceTransport) { + try { + workers().execute(() -> { + ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext context = threadContext.stashContext()) { - targetTransport.messageReceived(data, action, LocalTransport.this, version, requestId); + processReceivedMessage(data, action, sourceTransport, version, requestId); } }); + } catch (EsRejectedExecutionException e) { + assert lifecycle.started() == false; + logger.trace("received request but shutting down. ignoring. action [{}], request id [{}]", action, requestId); } } @@ -248,8 +267,9 @@ public class LocalTransport extends AbstractLifecycleComponent implements Transp return circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); } - protected void messageReceived(byte[] data, String action, LocalTransport sourceTransport, Version version, - @Nullable final Long sendRequestId) { + /** processes received messages, assuming thread passing and thread context have all been dealt with */ + protected void processReceivedMessage(byte[] data, String action, LocalTransport sourceTransport, Version version, + @Nullable final Long sendRequestId) { Transports.assertTransportThread(); try { transportServiceAdapter.received(data.length); diff --git a/core/src/main/java/org/elasticsearch/transport/local/LocalTransportChannel.java b/core/src/main/java/org/elasticsearch/transport/local/LocalTransportChannel.java index fc748b96aea..0c1e8747a12 100644 --- a/core/src/main/java/org/elasticsearch/transport/local/LocalTransportChannel.java +++ b/core/src/main/java/org/elasticsearch/transport/local/LocalTransportChannel.java @@ -107,12 +107,7 @@ public class LocalTransportChannel implements TransportChannel { private void sendResponseData(byte[] data) { close(); - targetTransport.workers().execute(() -> { - ThreadContext threadContext = targetTransport.threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - targetTransport.messageReceived(data, action, sourceTransport, version, null); - } - }); + targetTransport.receiveMessage(version, data, action, null, sourceTransport); } private void close() { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 42275c75e5a..87087e772ab 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -21,6 +21,8 @@ package org.elasticsearch.transport; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -28,6 +30,8 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.transport.MockTransportService; @@ -41,16 +45,19 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -75,24 +82,10 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(getClass().getName()); - serviceA = build( - Settings.builder() - .put("name", "TS_A") - .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") - .build(), - version0); - serviceA.acceptIncomingRequests(); + serviceA = buildService("TS_A", version0); nodeA = new DiscoveryNode("TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); // serviceA.setLocalNode(nodeA); - serviceB = build( - Settings.builder() - .put("name", "TS_B") - .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") - .build(), - version1); - serviceB.acceptIncomingRequests(); + serviceB = buildService("TS_B", version1); nodeB = new DiscoveryNode("TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); //serviceB.setLocalNode(nodeB); // wait till all nodes are properly connected and the event has been sent, so tests in this class @@ -131,6 +124,18 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { serviceB.removeConnectionListener(waitForConnection); } + private MockTransportService buildService(final String name, final Version version) { + MockTransportService service = build( + Settings.builder() + .put("name", name) + .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") + .build(), + version); + service.acceptIncomingRequests(); + return service; + } + @Override @After public void tearDown() throws Exception { @@ -483,6 +488,122 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertThat(latch.await(5, TimeUnit.SECONDS), equalTo(true)); } + public void testConcurrentSendRespondAndDisconnect() throws BrokenBarrierException, InterruptedException { + Set sendingErrors = ConcurrentCollections.newConcurrentSet(); + Set responseErrors = ConcurrentCollections.newConcurrentSet(); + serviceA.registerRequestHandler("test", TestRequest::new, + randomBoolean() ? ThreadPool.Names.SAME : ThreadPool.Names.GENERIC, (request, channel) -> { + try { + channel.sendResponse(new TestResponse()); + } catch (Exception e) { + logger.info("caught exception while responding", e); + responseErrors.add(e); + } + }); + final TransportRequestHandler ignoringRequestHandler = (request, channel) -> { + try { + channel.sendResponse(new TestResponse()); + } catch (Exception e) { + // we don't really care what's going on B, we're testing through A + logger.trace("caught exception while res ponding from node B", e); + } + }; + serviceB.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); + + int halfSenders = scaledRandomIntBetween(3, 10); + final CyclicBarrier go = new CyclicBarrier(halfSenders * 2 + 1); + final CountDownLatch done = new CountDownLatch(halfSenders * 2); + for (int i = 0; i < halfSenders; i++) { + // B senders just generated activity so serciveA can respond, we don't test what's going on there + final int sender = i; + threadPool.executor(ThreadPool.Names.GENERIC).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + logger.trace("caught exception while sending from B", e); + } + + @Override + protected void doRun() throws Exception { + go.await(); + for (int iter = 0; iter < 10; iter++) { + PlainActionFuture listener = new PlainActionFuture<>(); + final String info = sender + "_B_" + iter; + serviceB.sendRequest(nodeA, "test", new TestRequest(info), + new ActionListenerResponseHandler<>(listener, TestResponse::new)); + try { + listener.actionGet(); + + } catch (Exception e) { + logger.trace("caught exception while sending to node {}", e, nodeA); + } + } + } + + @Override + public void onAfter() { + done.countDown(); + } + }); + } + + for (int i = 0; i < halfSenders; i++) { + final int sender = i; + threadPool.executor(ThreadPool.Names.GENERIC).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + logger.error("unexpected error", e); + sendingErrors.add(e); + } + + @Override + protected void doRun() throws Exception { + go.await(); + for (int iter = 0; iter < 10; iter++) { + PlainActionFuture listener = new PlainActionFuture<>(); + final String info = sender + "_" + iter; + serviceA.sendRequest(nodeB, "test", new TestRequest(info), + new ActionListenerResponseHandler<>(listener, TestResponse::new)); + try { + listener.actionGet(); + } catch (ConnectTransportException e) { + // ok! + } catch (Exception e) { + logger.error("caught exception while sending to node {}", e, nodeB); + sendingErrors.add(e); + } + } + } + + @Override + public void onAfter() { + done.countDown(); + } + }); + } + go.await(); + for (int i = 0; i <= 10; i++) { + if (i % 3 == 0) { + // simulate restart of nodeB + serviceB.close(); + MockTransportService newService = buildService("TS_B", version1); + newService.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); + serviceB = newService; + nodeB = new DiscoveryNode("TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + serviceB.connectToNode(nodeA); + serviceA.connectToNode(nodeB); + } else if (serviceA.nodeConnected(nodeB)) { + serviceA.disconnectFromNode(nodeB); + } else { + serviceA.connectToNode(nodeB); + } + } + + done.await(); + + assertThat("found non connection errors while sending", sendingErrors, empty()); + assertThat("found non connection errors while responding", responseErrors, empty()); + } + public void testNotifyOnShutdown() throws Exception { final CountDownLatch latch2 = new CountDownLatch(1);