Handle race-condition when connection is closed before handshake listener was added

Today sending a message on a closed channel doesn't throw an exception. The channel
might just swallow the exception and informs the internal async exception handler
that a channel got disconnected. This change adds a safety check that we fail
the handshake if we registered a handler but the channel has been closed already
for instance due to a reset by peer.
This commit is contained in:
Simon Willnauer 2016-12-15 12:41:50 +01:00
parent 3005366b13
commit d27a12510b
2 changed files with 20 additions and 15 deletions

View File

@ -570,6 +570,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
}
} catch (IOException e) {
logger.warn("failed to close channel", e);
} finally {
onChannelClosed(channel);
}
});
}
@ -1527,6 +1529,13 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
pendingHandshakes.put(requestId, handler);
boolean success = false;
try {
if (isOpen(channel) == false) {
// we have to protect ourself here since sendRequestToChannel won't barf if the channel is closed.
// it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase.
// yet, if we don't check the state here we might have registered a pending handshake handler but the close
// listener calling #onChannelClosed might have already run and we are waiting on the latch below unitl we time out.
throw new IllegalStateException("handshake failed, channel already closed");
}
// for the request we use the minCompatVersion since we don't know what's the version of the node we talk to
// we also have no payload on the request but the response will contain the actual version of the node we talk
// to as the payload.
@ -1575,7 +1584,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
}
/**
* Called by sub-classes for each channel that is closed
* Called once the channel is closed for instance due to a disconnect or a closed socket etc.
*/
protected final void onChannelClosed(Channel channel) {
Optional<Map.Entry<Long, HandshakeResponseHandler>> first = pendingHandshakes.entrySet().stream()

View File

@ -38,7 +38,6 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
@ -51,6 +50,7 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -67,11 +67,13 @@ import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
@ -828,15 +830,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS));
}
@TestLogging(value = "org.elasticsearch.test.transport.tracer:TRACE")
public void testTracerLog() throws InterruptedException {
TransportRequestHandler handler = new TransportRequestHandler<StringMessageRequest>() {
@Override
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
channel.sendResponse(new StringMessageResponse(""));
}
};
TransportRequestHandler handler = (request, channel) -> channel.sendResponse(new StringMessageResponse(""));
TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
@Override
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
@ -1860,9 +1855,10 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
Thread t = new Thread() {
@Override
public void run() {
try {
Socket accept = socket.accept();
accept.close();
try (Socket accept = socket.accept()) {
if (randomBoolean()) { // sometimes wait until the other side sends the message
accept.getInputStream().read();
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
@ -1879,8 +1875,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
builder.setHandshakeTimeout(TimeValue.timeValueHours(1));
ConnectTransportException ex = expectThrows(ConnectTransportException.class,
() -> serviceA.connectToNode(dummy, builder.build()));
assertEquals("[][" + dummy.getAddress() +"] general node connection failure", ex.getMessage());
assertEquals("handshake failed", ex.getCause().getMessage());
assertEquals(ex.getMessage(), "[][" + dummy.getAddress() +"] general node connection failure");
assertThat(ex.getCause().getMessage(), startsWith("handshake failed"));
t.join();
}
}