Isolate nio channel registered from channel active (#44388)

Registering a channel with a selector is a required operation for the
channel to be handled properly. Currently, we mix the registeration with
other setup operations (ip filtering, SSL initiation, etc). However, a
fail to register is fatal. This PR modifies how registeration occurs to
immediately close the channel if it fails.

There are still two clear loopholes for how a user can interact with a
channel even if registration fails. 1. through the exception handler.
2. through the channel accepted callback. These can perhaps be improved
in the future. For now, this PR prevents writes from proceeding if the
channel is not registered.
This commit is contained in:
Tim Brooks 2019-07-16 18:46:41 -04:00
parent cc0ff3aa71
commit 0a352486e8
No known key found for this signature in database
GPG Key ID: C2AA3BB91A889E77
18 changed files with 201 additions and 110 deletions

View File

@ -35,7 +35,7 @@ public abstract class BytesWriteHandler implements NioChannelHandler {
}
@Override
public void channelRegistered() {}
public void channelActive() {}
@Override
public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {

View File

@ -50,17 +50,19 @@ public abstract class ChannelContext<S extends SelectableChannel & NetworkChanne
doSelectorRegister();
}
protected void channelActive() throws IOException {}
// Package private for testing
void doSelectorRegister() throws IOException {
setSelectionKey(rawChannel.register(getSelector().rawSelector(), 0));
setSelectionKey(rawChannel.register(getSelector().rawSelector(), 0, this));
}
SelectionKey getSelectionKey() {
protected SelectionKey getSelectionKey() {
return selectionKey;
}
// Protected for tests
protected void setSelectionKey(SelectionKey selectionKey) {
// public for tests
public void setSelectionKey(SelectionKey selectionKey) {
this.selectionKey = selectionKey;
}

View File

@ -32,8 +32,8 @@ public abstract class DelegatingHandler implements NioChannelHandler {
}
@Override
public void channelRegistered() {
this.delegate.channelRegistered();
public void channelActive() {
this.delegate.channelActive();
}
@Override

View File

@ -63,8 +63,28 @@ public class EventHandler {
*/
protected void handleRegistration(ChannelContext<?> context) throws IOException {
context.register();
SelectionKey selectionKey = context.getSelectionKey();
selectionKey.attach(context);
assert context.getSelectionKey() != null : "SelectionKey should not be null after registration";
assert context.getSelectionKey().attachment() != null : "Attachment should not be null after registration";
}
/**
* This method is called when an attempt to register a channel throws an exception.
*
* @param context that was registered
* @param exception that occurred
*/
protected void registrationException(ChannelContext<?> context, Exception exception) {
context.handleException(exception);
}
/**
* This method is called after a NioChannel is active with the selector. It should only be called once
* per channel.
*
* @param context that was marked active
*/
protected void handleActive(ChannelContext<?> context) throws IOException {
context.channelActive();
if (context instanceof SocketChannelContext) {
if (((SocketChannelContext) context).readyForFlush()) {
SelectionKeyUtils.setConnectReadAndWriteInterested(context.getSelectionKey());
@ -78,12 +98,12 @@ public class EventHandler {
}
/**
* This method is called when an attempt to register a channel throws an exception.
* This method is called when setting a channel to active throws an exception.
*
* @param context that was registered
* @param context that was marked active
* @param exception that occurred
*/
protected void registrationException(ChannelContext<?> context, Exception exception) {
protected void activeException(ChannelContext<?> context, Exception exception) {
context.handleException(exception);
}
@ -180,15 +200,9 @@ public class EventHandler {
closeException(context, e);
}
} else {
boolean pendingWrites = context.readyForFlush();
SelectionKey selectionKey = context.getSelectionKey();
if (selectionKey == null) {
if (pendingWrites) {
writeException(context, new IllegalStateException("Tried to write to an not yet registered channel"));
}
return;
}
boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey);
boolean pendingWrites = context.readyForFlush();
if (currentlyWriteInterested == false && pendingWrites) {
SelectionKeyUtils.setWriteInterested(selectionKey);
} else if (currentlyWriteInterested && pendingWrites == false) {

View File

@ -29,9 +29,9 @@ import java.util.function.BiConsumer;
public interface NioChannelHandler {
/**
* This method is called when the channel is registered with its selector.
* This method is called when the channel is active for use.
*/
void channelRegistered();
void channelActive();
/**
* This method is called when a message is queued with a channel. It can be called from any thread.

View File

@ -340,6 +340,14 @@ public class NioSelector implements Closeable {
private void writeToChannel(WriteOperation writeOperation) {
assertOnSelectorThread();
SocketChannelContext context = writeOperation.getChannel();
if (context.isOpen() == false) {
executeFailedListener(writeOperation.getListener(), new ClosedChannelException());
} else if (context.getSelectionKey() == null) {
// This should very rarely happen. The only times a channel is exposed outside the event loop,
// but might not registered is through the exception handler and channel accepted callbacks.
executeFailedListener(writeOperation.getListener(), new IllegalStateException("Channel not registered"));
} else {
// If the channel does not currently have anything that is ready to flush, we should flush after
// the write operation is queued.
boolean shouldFlushAfterQueuing = context.readyForFlush() == false;
@ -357,6 +365,7 @@ public class NioSelector implements Closeable {
eventHandler.postHandling(context);
}
}
}
/**
* Executes a success listener with consistent exception handling. This can only be called from current
@ -435,14 +444,25 @@ public class NioSelector implements Closeable {
try {
if (newChannel.isOpen()) {
eventHandler.handleRegistration(newChannel);
channelActive(newChannel);
if (newChannel instanceof SocketChannelContext) {
attemptConnect((SocketChannelContext) newChannel, false);
}
} else {
eventHandler.registrationException(newChannel, new ClosedChannelException());
closeChannel(newChannel);
}
} catch (Exception e) {
eventHandler.registrationException(newChannel, e);
closeChannel(newChannel);
}
}
private void channelActive(ChannelContext<?> newChannel) {
try {
eventHandler.handleActive(newChannel);
} catch (IOException e) {
eventHandler.activeException(newChannel, e);
}
}
@ -464,11 +484,7 @@ public class NioSelector implements Closeable {
private void handleQueuedWrites() {
WriteOperation writeOperation;
while ((writeOperation = queuedWrites.poll()) != null) {
if (writeOperation.getChannel().isOpen()) {
writeToChannel(writeOperation);
} else {
executeFailedListener(writeOperation.getListener(), new ClosedChannelException());
}
}
}

View File

@ -156,9 +156,8 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
}
@Override
protected void register() throws IOException {
super.register();
readWriteHandler.channelRegistered();
protected void channelActive() throws IOException {
readWriteHandler.channelActive();
}
@Override

View File

@ -81,32 +81,25 @@ public class EventHandlerTests extends ESTestCase {
}
public void testRegisterCallsContext() throws IOException {
NioSocketChannel channel = mock(NioSocketChannel.class);
SocketChannelContext channelContext = mock(SocketChannelContext.class);
when(channel.getContext()).thenReturn(channelContext);
when(channelContext.getSelectionKey()).thenReturn(new TestSelectionKey(0));
ChannelContext<?> channelContext = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class);
TestSelectionKey attachment = new TestSelectionKey(0);
when(channelContext.getSelectionKey()).thenReturn(attachment);
attachment.attach(channelContext);
handler.handleRegistration(channelContext);
verify(channelContext).register();
}
public void testRegisterNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException {
public void testActiveNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException {
SocketChannelContext context = mock(SocketChannelContext.class);
when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0));
handler.handleRegistration(context);
handler.handleActive(context);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, context.getSelectionKey().interestOps());
}
public void testRegisterAddsAttachment() throws IOException {
ChannelContext<?> context = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class);
when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0));
handler.handleRegistration(context);
assertEquals(context, context.getSelectionKey().attachment());
}
public void testHandleServerRegisterSetsOP_ACCEPTInterest() throws IOException {
assertNull(serverContext.getSelectionKey());
handler.handleRegistration(serverContext);
public void testHandleServerActiveSetsOP_ACCEPTInterest() throws IOException {
ServerChannelContext serverContext = mock(ServerChannelContext.class);
when(serverContext.getSelectionKey()).thenReturn(new TestSelectionKey(0));
handler.handleActive(serverContext);
assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps());
}
@ -141,11 +134,11 @@ public class EventHandlerTests extends ESTestCase {
verify(serverChannelContext).handleException(exception);
}
public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException {
public void testActiveWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException {
FlushReadyWrite flushReadyWrite = mock(FlushReadyWrite.class);
when(readWriteHandler.writeToBytes(flushReadyWrite)).thenReturn(Collections.singletonList(flushReadyWrite));
context.queueWriteOperation(flushReadyWrite);
handler.handleRegistration(context);
handler.handleActive(context);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps());
}
@ -266,7 +259,9 @@ public class EventHandlerTests extends ESTestCase {
@Override
public void register() {
setSelectionKey(new TestSelectionKey(0));
TestSelectionKey selectionKey = new TestSelectionKey(0);
setSelectionKey(selectionKey);
selectionKey.attach(this);
}
}
@ -280,7 +275,9 @@ public class EventHandlerTests extends ESTestCase {
@Override
public void register() {
TestSelectionKey selectionKey = new TestSelectionKey(0);
setSelectionKey(new TestSelectionKey(0));
selectionKey.attach(this);
}
}
}

View File

@ -212,6 +212,7 @@ public class NioSelectorTests extends ESTestCase {
selector.preSelect();
verify(eventHandler).handleRegistration(serverChannelContext);
verify(eventHandler).handleActive(serverChannelContext);
}
public void testClosedServerChannelWillNotBeRegistered() {
@ -230,7 +231,20 @@ public class NioSelectorTests extends ESTestCase {
selector.preSelect();
verify(eventHandler, times(0)).handleActive(serverChannelContext);
verify(eventHandler).registrationException(serverChannelContext, closedChannelException);
verify(eventHandler).handleClose(serverChannelContext);
}
public void testChannelActiveException() throws Exception {
executeOnNewThread(() -> selector.scheduleForRegistration(serverChannel));
IOException ioException = new IOException();
doThrow(ioException).when(eventHandler).handleActive(serverChannelContext);
selector.preSelect();
verify(eventHandler).handleActive(serverChannelContext);
verify(eventHandler).activeException(serverChannelContext, ioException);
}
public void testClosedSocketChannelWillNotBeRegistered() throws Exception {
@ -241,6 +255,7 @@ public class NioSelectorTests extends ESTestCase {
verify(eventHandler).registrationException(same(channelContext), any(ClosedChannelException.class));
verify(eventHandler, times(0)).handleConnect(channelContext);
verify(eventHandler).handleClose(channelContext);
}
public void testRegisterSocketChannelFailsDueToException() throws InterruptedException {
@ -253,7 +268,9 @@ public class NioSelectorTests extends ESTestCase {
selector.preSelect();
verify(eventHandler).registrationException(channelContext, closedChannelException);
verify(eventHandler, times(0)).handleActive(serverChannelContext);
verify(eventHandler, times(0)).handleConnect(channelContext);
verify(eventHandler).handleClose(channelContext);
});
}
@ -313,6 +330,17 @@ public class NioSelectorTests extends ESTestCase {
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
}
public void testQueueWriteChannelIsUnregistered() throws Exception {
WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener);
executeOnNewThread(() -> selector.queueWrite(writeOperation));
when(channelContext.getSelectionKey()).thenReturn(null);
selector.preSelect();
verify(channelContext, times(0)).queueWriteOperation(writeOperation);
verify(listener).accept(isNull(Void.class), any(IllegalStateException.class));
}
public void testQueueWriteSuccessful() throws Exception {
WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener);
executeOnNewThread(() -> selector.queueWrite(writeOperation));

View File

@ -53,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase {
private NioSocketChannel channel;
private BiConsumer<Void, Exception> listener;
private NioSelector selector;
private NioChannelHandler readWriteHandler;
private NioChannelHandler handler;
private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
@SuppressWarnings("unchecked")
@ -67,9 +67,9 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
selector = mock(NioSelector.class);
readWriteHandler = mock(NioChannelHandler.class);
handler = mock(NioChannelHandler.class);
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer);
when(selector.isOnCurrentThread()).thenReturn(true);
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
@ -142,6 +142,11 @@ public class SocketChannelContextTests extends ESTestCase {
assertSame(ioException, exception.get());
}
public void testChannelActiveCallsHandler() throws IOException {
context.channelActive();
verify(handler).channelActive();
}
public void testWriteFailsIfClosing() {
context.closeChannel();
@ -158,7 +163,7 @@ public class SocketChannelContextTests extends ESTestCase {
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
WriteOperation writeOperation = mock(WriteOperation.class);
when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation);
when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation);
context.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
@ -172,7 +177,7 @@ public class SocketChannelContextTests extends ESTestCase {
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
WriteOperation writeOperation = mock(WriteOperation.class);
when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation);
when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation);
context.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
@ -186,16 +191,16 @@ public class SocketChannelContextTests extends ESTestCase {
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
FlushReadyWrite writeOperation = new FlushReadyWrite(context, buffer, listener);
when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation));
when(handler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation));
context.queueWriteOperation(writeOperation);
verify(readWriteHandler).writeToBytes(writeOperation);
verify(handler).writeToBytes(writeOperation);
assertTrue(context.readyForFlush());
}
public void testHandleReadBytesWillCheckForNewFlushOperations() throws IOException {
assertFalse(context.readyForFlush());
when(readWriteHandler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class)));
when(handler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class)));
context.handleReadBytes();
assertTrue(context.readyForFlush());
}
@ -205,14 +210,14 @@ public class SocketChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer);
assertFalse(context.readyForFlush());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
WriteOperation writeOperation = mock(WriteOperation.class);
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener),
when(handler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener),
new FlushOperation(buffer, listener2)));
context.queueWriteOperation(writeOperation);
@ -233,7 +238,7 @@ public class SocketChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer);
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
@ -241,7 +246,7 @@ public class SocketChannelContextTests extends ESTestCase {
assertFalse(context.readyForFlush());
when(channel.isOpen()).thenReturn(true);
when(readWriteHandler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener),
when(handler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener),
new FlushOperation(buffer, listener2)));
context.closeFromSelector();
@ -257,9 +262,9 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);
BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer);
context.closeFromSelector();
verify(readWriteHandler).close();
verify(handler).close();
}
}
@ -271,7 +276,7 @@ public class SocketChannelContextTests extends ESTestCase {
IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator);
buffer.ensureCapacity(1);
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer);
context.closeFromSelector();
verify(closer).run();
}

View File

@ -58,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler {
private final TaskScheduler taskScheduler;
private final LongSupplier nanoClock;
private final long readTimeoutNanos;
private boolean channelRegistered = false;
private boolean channelActive = false;
private boolean requestSinceReadTimeoutTrigger = false;
private int inFlightRequests = 0;
@ -91,8 +91,8 @@ public class HttpReadWriteHandler implements NioChannelHandler {
}
@Override
public void channelRegistered() {
channelRegistered = true;
public void channelActive() {
channelActive = true;
if (readTimeoutNanos > 0) {
scheduleReadTimeout();
}
@ -100,7 +100,7 @@ public class HttpReadWriteHandler implements NioChannelHandler {
@Override
public int consumeReads(InboundChannelBuffer channelBuffer) {
assert channelRegistered : "channelRegistered should have been called";
assert channelActive : "channelActive should have been called";
int bytesConsumed = adaptor.read(channelBuffer.sliceAndRetainPagesTo(channelBuffer.getIndex()));
Object message;
while ((message = adaptor.pollInboundMessage()) != null) {
@ -123,7 +123,7 @@ public class HttpReadWriteHandler implements NioChannelHandler {
public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
assert writeOperation.getObject() instanceof NioHttpResponse : "This channel only supports messages that are of type: "
+ NioHttpResponse.class + ". Found type: " + writeOperation.getObject().getClass() + ".";
assert channelRegistered : "channelRegistered should have been called";
assert channelActive : "channelActive should have been called";
--inFlightRequests;
assert inFlightRequests >= 0 : "Inflight requests should never drop below zero, found: " + inFlightRequests;
adaptor.write(writeOperation);

View File

@ -100,7 +100,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
NioCorsConfig corsConfig = NioCorsConfigBuilder.forAnyOrigin().build();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime);
handler.channelRegistered();
handler.channelActive();
}
public void testSuccessfulDecodeHttpRequest() throws IOException {
@ -334,7 +334,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
Iterator<Integer> timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next);
handler.channelRegistered();
handler.channelActive();
prepareHandlerForResponse(handler);
SocketChannelContext context = mock(SocketChannelContext.class);
@ -381,7 +381,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler,
System::nanoTime);
handler.channelRegistered();
handler.channelActive();
prepareHandlerForResponse(handler);
DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {

View File

@ -227,7 +227,7 @@ class NioHttpClient implements Closeable {
}
@Override
public void channelRegistered() {}
public void channelActive() {}
@Override
public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer<Void, Exception> listener) {

View File

@ -92,6 +92,30 @@ public class TestEventHandler extends EventHandler {
}
}
@Override
protected void handleActive(ChannelContext<?> context) throws IOException {
final boolean registered = transportThreadWatchdog.register();
try {
super.handleActive(context);
} finally {
if (registered) {
transportThreadWatchdog.unregister();
}
}
}
@Override
protected void activeException(ChannelContext<?> context, Exception exception) {
final boolean registered = transportThreadWatchdog.register();
try {
super.activeException(context, exception);
} finally {
if (registered) {
transportThreadWatchdog.unregister();
}
}
}
public void handleConnect(SocketChannelContext context) throws IOException {
assert hasConnectedMap.contains(context) == false : "handleConnect should only be called is a channel is not yet connected";
final boolean registered = transportThreadWatchdog.register();

View File

@ -28,9 +28,9 @@ public final class NioIPFilter extends DelegatingHandler {
}
@Override
public void channelRegistered() {
public void channelActive() {
if (filter.accept(profile, remoteAddress)) {
super.channelRegistered();
super.channelActive();
} else {
denied = true;
}

View File

@ -55,8 +55,8 @@ public final class SSLChannelContext extends SocketChannelContext {
}
@Override
public void register() throws IOException {
super.register();
protected void channelActive() throws IOException {
super.channelActive();
sslDriver.init();
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
@ -179,10 +179,17 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public void closeChannel() {
if (isClosing.compareAndSet(false, true)) {
// The model for closing channels will change at some point, removing the need for this "schedule
// a write" signal. But for now, we need to handle the edge case where the channel is not
// registered.
if (getSelectionKey() == null) {
getSelector().queueChannelClose(channel);
} else {
WriteOperation writeOperation = new CloseNotifyOperation(this);
getSelector().queueWrite(writeOperation);
}
}
}
@Override
public void closeFromSelector() throws IOException {

View File

@ -82,8 +82,8 @@ public class NioIPFilterTests extends ESTestCase {
InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345);
NioChannelHandler delegate = mock(NioChannelHandler.class);
NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
nioIPFilter.channelRegistered();
verify(delegate).channelRegistered();
nioIPFilter.channelActive();
verify(delegate).channelActive();
assertFalse(nioIPFilter.closeNow());
}
@ -91,8 +91,8 @@ public class NioIPFilterTests extends ESTestCase {
InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345);
NioChannelHandler delegate = mock(NioChannelHandler.class);
NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
nioIPFilter.channelRegistered();
verify(delegate, times(0)).channelRegistered();
nioIPFilter.channelActive();
verify(delegate, times(0)).channelActive();
assertTrue(nioIPFilter.closeNow());
}
}

View File

@ -24,6 +24,7 @@ import org.mockito.stubbing.Answer;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.function.BiConsumer;
@ -73,6 +74,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class));
when(selector.isOnCurrentThread()).thenReturn(true);
when(selector.getTaskScheduler()).thenReturn(nioTimer);
@ -331,6 +333,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class));
context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
verify(selector).queueWrite(captor.capture());
@ -345,18 +348,7 @@ public class SSLChannelContextTests extends ESTestCase {
}
}
public void testInitiateCloseFromDifferentThreadSchedulesCloseNotify() throws SSLException {
when(selector.isOnCurrentThread()).thenReturn(false, true);
context.closeChannel();
ArgumentCaptor<FlushReadyWrite> captor = ArgumentCaptor.forClass(FlushReadyWrite.class);
verify(selector).queueWrite(captor.capture());
context.queueWriteOperation(captor.getValue());
verify(sslDriver).initiateClose();
}
public void testInitiateCloseFromSameThreadSchedulesCloseNotify() throws SSLException {
public void testInitiateCloseSchedulesCloseNotify() throws SSLException {
context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
@ -366,8 +358,15 @@ public class SSLChannelContextTests extends ESTestCase {
verify(sslDriver).initiateClose();
}
public void testInitiateUnregisteredScheduledDirectClose() throws SSLException {
context.setSelectionKey(null);
context.closeChannel();
verify(selector).queueChannelClose(channel);
}
@SuppressWarnings("unchecked")
public void testRegisterInitiatesDriver() throws IOException {
public void testActiveInitiatesDriver() throws IOException {
try (Selector realSelector = Selector.open();
SocketChannel realSocket = SocketChannel.open()) {
realSocket.configureBlocking(false);
@ -375,7 +374,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realSocket);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.register();
context.channelActive();
verify(sslDriver).init();
}
}