Handle race condition when the connection is closed before the handshake listener is added

Today, sending a message on a closed channel doesn't throw an exception. The channel
might just swallow the exception and inform the internal async exception handler
that the channel got disconnected. This change adds a safety check that fails
the handshake if we registered a handler but the channel has already been closed,
for instance due to a reset by peer.
This commit is contained in:
Simon Willnauer 2016-12-15 12:41:50 +01:00
parent 3005366b13
commit d27a12510b
2 changed files with 20 additions and 15 deletions

View File

@ -570,6 +570,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
} }
} catch (IOException e) { } catch (IOException e) {
logger.warn("failed to close channel", e); logger.warn("failed to close channel", e);
} finally {
onChannelClosed(channel);
} }
}); });
} }
@ -1527,6 +1529,13 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
pendingHandshakes.put(requestId, handler); pendingHandshakes.put(requestId, handler);
boolean success = false; boolean success = false;
try { try {
if (isOpen(channel) == false) {
// we have to protect ourselves here since sendRequestToChannel won't throw if the channel is closed.
// it's weird, but changing it would have a large impact on the exception handling code all over the codebase.
// yet, if we don't check the state here, we might have registered a pending handshake handler after the close
// listener calling #onChannelClosed has already run, and we would wait on the latch below until we time out.
throw new IllegalStateException("handshake failed, channel already closed");
}
// for the request we use the minCompatVersion since we don't know what's the version of the node we talk to // for the request we use the minCompatVersion since we don't know what's the version of the node we talk to
// we also have no payload on the request but the response will contain the actual version of the node we talk // we also have no payload on the request but the response will contain the actual version of the node we talk
// to as the payload. // to as the payload.
@ -1575,7 +1584,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
} }
/** /**
* Called by sub-classes for each channel that is closed * Called once the channel is closed for instance due to a disconnect or a closed socket etc.
*/ */
protected final void onChannelClosed(Channel channel) { protected final void onChannelClosed(Channel channel) {
Optional<Map.Entry<Long, HandshakeResponseHandler>> first = pendingHandshakes.entrySet().stream() Optional<Map.Entry<Long, HandshakeResponseHandler>> first = pendingHandshakes.entrySet().stream()

View File

@ -38,7 +38,6 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -51,6 +50,7 @@ import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -67,11 +67,13 @@ import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet; import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
public abstract class AbstractSimpleTransportTestCase extends ESTestCase { public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
@ -828,15 +830,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS)); assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS));
} }
@TestLogging(value = "org.elasticsearch.test.transport.tracer:TRACE")
public void testTracerLog() throws InterruptedException { public void testTracerLog() throws InterruptedException {
TransportRequestHandler handler = new TransportRequestHandler<StringMessageRequest>() { TransportRequestHandler handler = (request, channel) -> channel.sendResponse(new StringMessageResponse(""));
@Override
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
channel.sendResponse(new StringMessageResponse(""));
}
};
TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() { TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
@Override @Override
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception { public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
@ -1860,9 +1855,10 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
Thread t = new Thread() { Thread t = new Thread() {
@Override @Override
public void run() { public void run() {
try { try (Socket accept = socket.accept()) {
Socket accept = socket.accept(); if (randomBoolean()) { // sometimes wait until the other side sends the message
accept.close(); accept.getInputStream().read();
}
} catch (IOException e) { } catch (IOException e) {
throw new UncheckedIOException(e); throw new UncheckedIOException(e);
} }
@ -1879,8 +1875,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
builder.setHandshakeTimeout(TimeValue.timeValueHours(1)); builder.setHandshakeTimeout(TimeValue.timeValueHours(1));
ConnectTransportException ex = expectThrows(ConnectTransportException.class, ConnectTransportException ex = expectThrows(ConnectTransportException.class,
() -> serviceA.connectToNode(dummy, builder.build())); () -> serviceA.connectToNode(dummy, builder.build()));
assertEquals("[][" + dummy.getAddress() +"] general node connection failure", ex.getMessage()); assertEquals(ex.getMessage(), "[][" + dummy.getAddress() +"] general node connection failure");
assertEquals("handshake failed", ex.getCause().getMessage()); assertThat(ex.getCause().getMessage(), startsWith("handshake failed"));
t.join(); t.join();
} }
} }