Handle race-condition when connection is closed before handshake listener was added
Today, sending a message on a closed channel doesn't throw an exception. The channel might just swallow the exception and inform the internal async exception handler that the channel got disconnected. This change adds a safety check so that we fail the handshake if we registered a handler but the channel has already been closed, for instance due to a reset by peer.
This commit is contained in:
parent
3005366b13
commit
d27a12510b
|
@ -570,6 +570,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
logger.warn("failed to close channel", e);
|
logger.warn("failed to close channel", e);
|
||||||
|
} finally {
|
||||||
|
onChannelClosed(channel);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1527,6 +1529,13 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
|
||||||
pendingHandshakes.put(requestId, handler);
|
pendingHandshakes.put(requestId, handler);
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
|
if (isOpen(channel) == false) {
|
||||||
|
// we have to protect ourself here since sendRequestToChannel won't barf if the channel is closed.
|
||||||
|
// it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase.
|
||||||
|
// yet, if we don't check the state here we might have registered a pending handshake handler but the close
|
||||||
|
// listener calling #onChannelClosed might have already run and we are waiting on the latch below until we time out.
|
||||||
|
throw new IllegalStateException("handshake failed, channel already closed");
|
||||||
|
}
|
||||||
// for the request we use the minCompatVersion since we don't know what's the version of the node we talk to
|
// for the request we use the minCompatVersion since we don't know what's the version of the node we talk to
|
||||||
// we also have no payload on the request but the response will contain the actual version of the node we talk
|
// we also have no payload on the request but the response will contain the actual version of the node we talk
|
||||||
// to as the payload.
|
// to as the payload.
|
||||||
|
@ -1575,7 +1584,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called by sub-classes for each channel that is closed
|
* Called once the channel is closed for instance due to a disconnect or a closed socket etc.
|
||||||
*/
|
*/
|
||||||
protected final void onChannelClosed(Channel channel) {
|
protected final void onChannelClosed(Channel channel) {
|
||||||
Optional<Map.Entry<Long, HandshakeResponseHandler>> first = pendingHandshakes.entrySet().stream()
|
Optional<Map.Entry<Long, HandshakeResponseHandler>> first = pendingHandshakes.entrySet().stream()
|
||||||
|
|
|
@ -38,7 +38,6 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
|
||||||
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
|
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
|
||||||
import org.elasticsearch.node.Node;
|
import org.elasticsearch.node.Node;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.test.junit.annotations.TestLogging;
|
|
||||||
import org.elasticsearch.test.transport.MockTransportService;
|
import org.elasticsearch.test.transport.MockTransportService;
|
||||||
import org.elasticsearch.threadpool.TestThreadPool;
|
import org.elasticsearch.threadpool.TestThreadPool;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
@ -51,6 +50,7 @@ import java.net.InetAddress;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.ServerSocket;
|
import java.net.ServerSocket;
|
||||||
import java.net.Socket;
|
import java.net.Socket;
|
||||||
|
import java.nio.channels.ClosedChannelException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -67,11 +67,13 @@ import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import static java.util.Collections.emptyMap;
|
import static java.util.Collections.emptyMap;
|
||||||
import static java.util.Collections.emptySet;
|
import static java.util.Collections.emptySet;
|
||||||
|
import static org.hamcrest.Matchers.anyOf;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.endsWith;
|
import static org.hamcrest.Matchers.endsWith;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
import static org.hamcrest.Matchers.notNullValue;
|
import static org.hamcrest.Matchers.notNullValue;
|
||||||
|
import static org.hamcrest.Matchers.startsWith;
|
||||||
|
|
||||||
public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
||||||
|
|
||||||
|
@ -828,15 +830,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
||||||
assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS));
|
assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS));
|
||||||
}
|
}
|
||||||
|
|
||||||
@TestLogging(value = "org.elasticsearch.test.transport.tracer:TRACE")
|
|
||||||
public void testTracerLog() throws InterruptedException {
|
public void testTracerLog() throws InterruptedException {
|
||||||
TransportRequestHandler handler = new TransportRequestHandler<StringMessageRequest>() {
|
TransportRequestHandler handler = (request, channel) -> channel.sendResponse(new StringMessageResponse(""));
|
||||||
@Override
|
|
||||||
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
|
|
||||||
channel.sendResponse(new StringMessageResponse(""));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
|
TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
|
||||||
@Override
|
@Override
|
||||||
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
|
public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
|
||||||
|
@ -1860,9 +1855,10 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
||||||
Thread t = new Thread() {
|
Thread t = new Thread() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
try {
|
try (Socket accept = socket.accept()) {
|
||||||
Socket accept = socket.accept();
|
if (randomBoolean()) { // sometimes wait until the other side sends the message
|
||||||
accept.close();
|
accept.getInputStream().read();
|
||||||
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new UncheckedIOException(e);
|
throw new UncheckedIOException(e);
|
||||||
}
|
}
|
||||||
|
@ -1879,8 +1875,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
||||||
builder.setHandshakeTimeout(TimeValue.timeValueHours(1));
|
builder.setHandshakeTimeout(TimeValue.timeValueHours(1));
|
||||||
ConnectTransportException ex = expectThrows(ConnectTransportException.class,
|
ConnectTransportException ex = expectThrows(ConnectTransportException.class,
|
||||||
() -> serviceA.connectToNode(dummy, builder.build()));
|
() -> serviceA.connectToNode(dummy, builder.build()));
|
||||||
assertEquals("[][" + dummy.getAddress() +"] general node connection failure", ex.getMessage());
|
assertEquals(ex.getMessage(), "[][" + dummy.getAddress() +"] general node connection failure");
|
||||||
assertEquals("handshake failed", ex.getCause().getMessage());
|
assertThat(ex.getCause().getMessage(), startsWith("handshake failed"));
|
||||||
t.join();
|
t.join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue