From 0a352486e8f18fbc0da062a3d021ddead9c9f486 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 16 Jul 2019 18:46:41 -0400 Subject: [PATCH] Isolate nio channel registered from channel active (#44388) Registering a channel with a selector is a required operation for the channel to be handled properly. Currently, we mix the registeration with other setup operations (ip filtering, SSL initiation, etc). However, a fail to register is fatal. This PR modifies how registeration occurs to immediately close the channel if it fails. There are still two clear loopholes for how a user can interact with a channel even if registration fails. 1. through the exception handler. 2. through the channel accepted callback. These can perhaps be improved in the future. For now, this PR prevents writes from proceeding if the channel is not registered. --- .../elasticsearch/nio/BytesWriteHandler.java | 2 +- .../org/elasticsearch/nio/ChannelContext.java | 10 ++-- .../elasticsearch/nio/DelegatingHandler.java | 4 +- .../org/elasticsearch/nio/EventHandler.java | 38 +++++++++----- .../elasticsearch/nio/NioChannelHandler.java | 4 +- .../org/elasticsearch/nio/NioSelector.java | 52 ++++++++++++------- .../nio/SocketChannelContext.java | 5 +- .../elasticsearch/nio/EventHandlerTests.java | 37 ++++++------- .../elasticsearch/nio/NioSelectorTests.java | 28 ++++++++++ .../nio/SocketChannelContextTests.java | 35 +++++++------ .../http/nio/HttpReadWriteHandler.java | 10 ++-- .../http/nio/HttpReadWriteHandlerTests.java | 6 +-- .../elasticsearch/http/nio/NioHttpClient.java | 2 +- .../transport/nio/TestEventHandler.java | 24 +++++++++ .../security/transport/nio/NioIPFilter.java | 4 +- .../transport/nio/SSLChannelContext.java | 15 ++++-- .../transport/nio/NioIPFilterTests.java | 8 +-- .../transport/nio/SSLChannelContextTests.java | 27 +++++----- 18 files changed, 201 insertions(+), 110 deletions(-) diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java index 48d83d21692..8e590c830b9 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -35,7 +35,7 @@ public abstract class BytesWriteHandler implements NioChannelHandler { } @Override - public void channelRegistered() {} + public void channelActive() {} @Override public List writeToBytes(WriteOperation writeOperation) { diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java index a2663385daa..a030f68fe8b 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java @@ -50,17 +50,19 @@ public abstract class ChannelContext context) throws IOException { context.register(); - SelectionKey selectionKey = context.getSelectionKey(); - selectionKey.attach(context); + assert context.getSelectionKey() != null : "SelectionKey should not be null after registration"; + assert context.getSelectionKey().attachment() != null : "Attachment should not be null after registration"; + } + + /** + * This method is called when an attempt to register a channel throws an exception. + * + * @param context that was registered + * @param exception that occurred + */ + protected void registrationException(ChannelContext context, Exception exception) { + context.handleException(exception); + } + + /** + * This method is called after a NioChannel is active with the selector. It should only be called once + * per channel. + * + * @param context that was marked active + */ + protected void handleActive(ChannelContext context) throws IOException { + context.channelActive(); if (context instanceof SocketChannelContext) { if (((SocketChannelContext) context).readyForFlush()) { SelectionKeyUtils.setConnectReadAndWriteInterested(context.getSelectionKey()); @@ -78,12 +98,12 @@ public class EventHandler { } /** - * This method is called when an attempt to register a channel throws an exception. + * This method is called when setting a channel to active throws an exception. * - * @param context that was registered + * @param context that was marked active * @param exception that occurred */ - protected void registrationException(ChannelContext context, Exception exception) { + protected void activeException(ChannelContext context, Exception exception) { context.handleException(exception); } @@ -180,15 +200,9 @@ public class EventHandler { closeException(context, e); } } else { - boolean pendingWrites = context.readyForFlush(); SelectionKey selectionKey = context.getSelectionKey(); - if (selectionKey == null) { - if (pendingWrites) { - writeException(context, new IllegalStateException("Tried to write to an not yet registered channel")); - } - return; - } boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey); + boolean pendingWrites = context.readyForFlush(); if (currentlyWriteInterested == false && pendingWrites) { SelectionKeyUtils.setWriteInterested(selectionKey); } else if (currentlyWriteInterested && pendingWrites == false) { diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java index 61bda9a4507..2d91e769368 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java @@ -29,9 +29,9 @@ import java.util.function.BiConsumer; public interface NioChannelHandler { /** - * This method is called when the channel is registered with its selector. + * This method is called when the channel is active for use. */ - void channelRegistered(); + void channelActive(); /** * This method is called when a message is queued with a channel. It can be called from any thread. diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java index 175c7661813..cbc069c8f36 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java @@ -340,21 +340,30 @@ public class NioSelector implements Closeable { private void writeToChannel(WriteOperation writeOperation) { assertOnSelectorThread(); SocketChannelContext context = writeOperation.getChannel(); - // If the channel does not currently have anything that is ready to flush, we should flush after - // the write operation is queued. - boolean shouldFlushAfterQueuing = context.readyForFlush() == false; - try { - context.queueWriteOperation(writeOperation); - } catch (Exception e) { - shouldFlushAfterQueuing = false; - executeFailedListener(writeOperation.getListener(), e); - } - if (shouldFlushAfterQueuing) { - if (context.selectorShouldClose() == false) { - handleWrite(context); + if (context.isOpen() == false) { + executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); + } else if (context.getSelectionKey() == null) { + // This should very rarely happen. The only times a channel is exposed outside the event loop, + // but might not registered is through the exception handler and channel accepted callbacks. + executeFailedListener(writeOperation.getListener(), new IllegalStateException("Channel not registered")); + } else { + // If the channel does not currently have anything that is ready to flush, we should flush after + // the write operation is queued. + boolean shouldFlushAfterQueuing = context.readyForFlush() == false; + try { + context.queueWriteOperation(writeOperation); + } catch (Exception e) { + shouldFlushAfterQueuing = false; + executeFailedListener(writeOperation.getListener(), e); + } + + if (shouldFlushAfterQueuing) { + if (context.selectorShouldClose() == false) { + handleWrite(context); + } + eventHandler.postHandling(context); } - eventHandler.postHandling(context); } } @@ -435,14 +444,25 @@ public class NioSelector implements Closeable { try { if (newChannel.isOpen()) { eventHandler.handleRegistration(newChannel); + channelActive(newChannel); if (newChannel instanceof SocketChannelContext) { attemptConnect((SocketChannelContext) newChannel, false); } } else { eventHandler.registrationException(newChannel, new ClosedChannelException()); + closeChannel(newChannel); } } catch (Exception e) { eventHandler.registrationException(newChannel, e); + closeChannel(newChannel); + } + } + + private void channelActive(ChannelContext newChannel) { + try { + eventHandler.handleActive(newChannel); + } catch (IOException e) { + eventHandler.activeException(newChannel, e); } } @@ -464,11 +484,7 @@ public class NioSelector implements Closeable { private void handleQueuedWrites() { WriteOperation writeOperation; while ((writeOperation = queuedWrites.poll()) != null) { - if (writeOperation.getChannel().isOpen()) { - writeToChannel(writeOperation); - } else { - executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); - } + writeToChannel(writeOperation); } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index f77ccb17aef..bc93466b58a 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -156,9 +156,8 @@ public abstract class SocketChannelContext extends ChannelContext } @Override - protected void register() throws IOException { - super.register(); - readWriteHandler.channelRegistered(); + protected void channelActive() throws IOException { + readWriteHandler.channelActive(); } @Override diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java index 578890b152f..726d87317ff 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java @@ -81,32 +81,25 @@ public class EventHandlerTests extends ESTestCase { } public void testRegisterCallsContext() throws IOException { - NioSocketChannel channel = mock(NioSocketChannel.class); - SocketChannelContext channelContext = mock(SocketChannelContext.class); - when(channel.getContext()).thenReturn(channelContext); - when(channelContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + ChannelContext channelContext = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class); + TestSelectionKey attachment = new TestSelectionKey(0); + when(channelContext.getSelectionKey()).thenReturn(attachment); + attachment.attach(channelContext); handler.handleRegistration(channelContext); verify(channelContext).register(); } - public void testRegisterNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException { + public void testActiveNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException { SocketChannelContext context = mock(SocketChannelContext.class); when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(context); + handler.handleActive(context); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, context.getSelectionKey().interestOps()); } - public void testRegisterAddsAttachment() throws IOException { - ChannelContext context = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class); - when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(context); - assertEquals(context, context.getSelectionKey().attachment()); - } - - public void testHandleServerRegisterSetsOP_ACCEPTInterest() throws IOException { - assertNull(serverContext.getSelectionKey()); - - handler.handleRegistration(serverContext); + public void testHandleServerActiveSetsOP_ACCEPTInterest() throws IOException { + ServerChannelContext serverContext = mock(ServerChannelContext.class); + when(serverContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleActive(serverContext); assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps()); } @@ -141,11 +134,11 @@ public class EventHandlerTests extends ESTestCase { verify(serverChannelContext).handleException(exception); } - public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { + public void testActiveWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { FlushReadyWrite flushReadyWrite = mock(FlushReadyWrite.class); when(readWriteHandler.writeToBytes(flushReadyWrite)).thenReturn(Collections.singletonList(flushReadyWrite)); context.queueWriteOperation(flushReadyWrite); - handler.handleRegistration(context); + handler.handleActive(context); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps()); } @@ -266,7 +259,9 @@ public class EventHandlerTests extends ESTestCase { @Override public void register() { - setSelectionKey(new TestSelectionKey(0)); + TestSelectionKey selectionKey = new TestSelectionKey(0); + setSelectionKey(selectionKey); + selectionKey.attach(this); } } @@ -280,7 +275,9 @@ public class EventHandlerTests extends ESTestCase { @Override public void register() { + TestSelectionKey selectionKey = new TestSelectionKey(0); setSelectionKey(new TestSelectionKey(0)); + selectionKey.attach(this); } } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java index 89a01d02ed4..f7bf4bc24be 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java @@ -212,6 +212,7 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); verify(eventHandler).handleRegistration(serverChannelContext); + verify(eventHandler).handleActive(serverChannelContext); } public void testClosedServerChannelWillNotBeRegistered() { @@ -230,7 +231,20 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); + verify(eventHandler, times(0)).handleActive(serverChannelContext); verify(eventHandler).registrationException(serverChannelContext, closedChannelException); + verify(eventHandler).handleClose(serverChannelContext); + } + + public void testChannelActiveException() throws Exception { + executeOnNewThread(() -> selector.scheduleForRegistration(serverChannel)); + IOException ioException = new IOException(); + doThrow(ioException).when(eventHandler).handleActive(serverChannelContext); + + selector.preSelect(); + + verify(eventHandler).handleActive(serverChannelContext); + verify(eventHandler).activeException(serverChannelContext, ioException); } public void testClosedSocketChannelWillNotBeRegistered() throws Exception { @@ -241,6 +255,7 @@ public class NioSelectorTests extends ESTestCase { verify(eventHandler).registrationException(same(channelContext), any(ClosedChannelException.class)); verify(eventHandler, times(0)).handleConnect(channelContext); + verify(eventHandler).handleClose(channelContext); } public void testRegisterSocketChannelFailsDueToException() throws InterruptedException { @@ -253,7 +268,9 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); verify(eventHandler).registrationException(channelContext, closedChannelException); + verify(eventHandler, times(0)).handleActive(serverChannelContext); verify(eventHandler, times(0)).handleConnect(channelContext); + verify(eventHandler).handleClose(channelContext); }); } @@ -313,6 +330,17 @@ public class NioSelectorTests extends ESTestCase { verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); } + public void testQueueWriteChannelIsUnregistered() throws Exception { + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); + + executeOnNewThread(() -> selector.queueWrite(writeOperation)); + when(channelContext.getSelectionKey()).thenReturn(null); + selector.preSelect(); + + verify(channelContext, times(0)).queueWriteOperation(writeOperation); + verify(listener).accept(isNull(Void.class), any(IllegalStateException.class)); + } + public void testQueueWriteSuccessful() throws Exception { WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); executeOnNewThread(() -> selector.queueWrite(writeOperation)); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 5563ccc4306..210a27aa109 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -53,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase { private NioSocketChannel channel; private BiConsumer listener; private NioSelector selector; - private NioChannelHandler readWriteHandler; + private NioChannelHandler handler; private ByteBuffer ioBuffer = ByteBuffer.allocate(1024); @SuppressWarnings("unchecked") @@ -67,9 +67,9 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); selector = mock(NioSelector.class); - readWriteHandler = mock(NioChannelHandler.class); + handler = mock(NioChannelHandler.class); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { @@ -142,6 +142,11 @@ public class SocketChannelContextTests extends ESTestCase { assertSame(ioException, exception.get()); } + public void testChannelActiveCallsHandler() throws IOException { + context.channelActive(); + verify(handler).channelActive(); + } + public void testWriteFailsIfClosing() { context.closeChannel(); @@ -158,7 +163,7 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; WriteOperation writeOperation = mock(WriteOperation.class); - when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); context.sendMessage(buffers, listener); verify(selector).queueWrite(writeOpCaptor.capture()); @@ -172,7 +177,7 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; WriteOperation writeOperation = mock(WriteOperation.class); - when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); context.sendMessage(buffers, listener); verify(selector).queueWrite(writeOpCaptor.capture()); @@ -186,16 +191,16 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; FlushReadyWrite writeOperation = new FlushReadyWrite(context, buffer, listener); - when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation)); + when(handler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation)); context.queueWriteOperation(writeOperation); - verify(readWriteHandler).writeToBytes(writeOperation); + verify(handler).writeToBytes(writeOperation); assertTrue(context.readyForFlush()); } public void testHandleReadBytesWillCheckForNewFlushOperations() throws IOException { assertFalse(context.readyForFlush()); - when(readWriteHandler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class))); + when(handler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class))); context.handleReadBytes(); assertTrue(context.readyForFlush()); } @@ -205,14 +210,14 @@ public class SocketChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); assertFalse(context.readyForFlush()); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; WriteOperation writeOperation = mock(WriteOperation.class); BiConsumer listener2 = mock(BiConsumer.class); - when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + when(handler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), new FlushOperation(buffer, listener2))); context.queueWriteOperation(writeOperation); @@ -233,7 +238,7 @@ public class SocketChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; @@ -241,7 +246,7 @@ public class SocketChannelContextTests extends ESTestCase { assertFalse(context.readyForFlush()); when(channel.isOpen()).thenReturn(true); - when(readWriteHandler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + when(handler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), new FlushOperation(buffer, listener2))); context.closeFromSelector(); @@ -257,9 +262,9 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); - BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer); context.closeFromSelector(); - verify(readWriteHandler).close(); + verify(handler).close(); } } @@ -271,7 +276,7 @@ public class SocketChannelContextTests extends ESTestCase { IntFunction pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer); InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator); buffer.ensureCapacity(1); - TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer); context.closeFromSelector(); verify(closer).run(); } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index c603e20ffc9..9802b24a14a 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -58,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { private final TaskScheduler taskScheduler; private final LongSupplier nanoClock; private final long readTimeoutNanos; - private boolean channelRegistered = false; + private boolean channelActive = false; private boolean requestSinceReadTimeoutTrigger = false; private int inFlightRequests = 0; @@ -91,8 +91,8 @@ public class HttpReadWriteHandler implements NioChannelHandler { } @Override - public void channelRegistered() { - channelRegistered = true; + public void channelActive() { + channelActive = true; if (readTimeoutNanos > 0) { scheduleReadTimeout(); } @@ -100,7 +100,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { @Override public int consumeReads(InboundChannelBuffer channelBuffer) { - assert channelRegistered : "channelRegistered should have been called"; + assert channelActive : "channelActive should have been called"; int bytesConsumed = adaptor.read(channelBuffer.sliceAndRetainPagesTo(channelBuffer.getIndex())); Object message; while ((message = adaptor.pollInboundMessage()) != null) { @@ -123,7 +123,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { public List writeToBytes(WriteOperation writeOperation) { assert writeOperation.getObject() instanceof NioHttpResponse : "This channel only supports messages that are of type: " + NioHttpResponse.class + ". Found type: " + writeOperation.getObject().getClass() + "."; - assert channelRegistered : "channelRegistered should have been called"; + assert channelActive : "channelActive should have been called"; --inFlightRequests; assert inFlightRequests >= 0 : "Inflight requests should never drop below zero, found: " + inFlightRequests; adaptor.write(writeOperation); diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 93a846ea90f..253487179b2 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -100,7 +100,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { NioCorsConfig corsConfig = NioCorsConfigBuilder.forAnyOrigin().build(); handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime); - handler.channelRegistered(); + handler.channelActive(); } public void testSuccessfulDecodeHttpRequest() throws IOException { @@ -334,7 +334,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { Iterator timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator(); handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next); - handler.channelRegistered(); + handler.channelActive(); prepareHandlerForResponse(handler); SocketChannelContext context = mock(SocketChannelContext.class); @@ -381,7 +381,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler, System::nanoTime); - handler.channelRegistered(); + handler.channelActive(); prepareHandlerForResponse(handler); DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); if (originValue != null) { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java index ed55007f3ba..08fcf8bb44a 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java @@ -227,7 +227,7 @@ class NioHttpClient implements Closeable { } @Override - public void channelRegistered() {} + public void channelActive() {} @Override public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java index 069e19c3455..4a1c6f5deb6 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java @@ -92,6 +92,30 @@ public class TestEventHandler extends EventHandler { } } + @Override + protected void handleActive(ChannelContext context) throws IOException { + final boolean registered = transportThreadWatchdog.register(); + try { + super.handleActive(context); + } finally { + if (registered) { + transportThreadWatchdog.unregister(); + } + } + } + + @Override + protected void activeException(ChannelContext context, Exception exception) { + final boolean registered = transportThreadWatchdog.register(); + try { + super.activeException(context, exception); + } finally { + if (registered) { + transportThreadWatchdog.unregister(); + } + } + } + public void handleConnect(SocketChannelContext context) throws IOException { assert hasConnectedMap.contains(context) == false : "handleConnect should only be called is a channel is not yet connected"; final boolean registered = transportThreadWatchdog.register(); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java index 12f6b67d672..b90a6f4991c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java @@ -28,9 +28,9 @@ public final class NioIPFilter extends DelegatingHandler { } @Override - public void channelRegistered() { + public void channelActive() { if (filter.accept(profile, remoteAddress)) { - super.channelRegistered(); + super.channelActive(); } else { denied = true; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 8947447ef58..6a1684dd024 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -55,8 +55,8 @@ public final class SSLChannelContext extends SocketChannelContext { } @Override - public void register() throws IOException { - super.register(); + protected void channelActive() throws IOException { + super.channelActive(); sslDriver.init(); SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush()) { @@ -179,8 +179,15 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - WriteOperation writeOperation = new CloseNotifyOperation(this); - getSelector().queueWrite(writeOperation); + // The model for closing channels will change at some point, removing the need for this "schedule + // a write" signal. But for now, we need to handle the edge case where the channel is not + // registered. + if (getSelectionKey() == null) { + getSelector().queueChannelClose(channel); + } else { + WriteOperation writeOperation = new CloseNotifyOperation(this); + getSelector().queueWrite(writeOperation); + } } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java index e7612c0c0d7..842c9f031ef 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java @@ -82,8 +82,8 @@ public class NioIPFilterTests extends ESTestCase { InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345); NioChannelHandler delegate = mock(NioChannelHandler.class); NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); - nioIPFilter.channelRegistered(); - verify(delegate).channelRegistered(); + nioIPFilter.channelActive(); + verify(delegate).channelActive(); assertFalse(nioIPFilter.closeNow()); } @@ -91,8 +91,8 @@ public class NioIPFilterTests extends ESTestCase { InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345); NioChannelHandler delegate = mock(NioChannelHandler.class); NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); - nioIPFilter.channelRegistered(); - verify(delegate, times(0)).channelRegistered(); + nioIPFilter.channelActive(); + verify(delegate, times(0)).channelActive(); assertTrue(nioIPFilter.closeNow()); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 7efff1c0e26..8e0a5ad23af 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -24,6 +24,7 @@ import org.mockito.stubbing.Answer; import javax.net.ssl.SSLException; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; @@ -73,6 +74,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context.setSelectionKey(mock(SelectionKey.class)); when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); @@ -331,6 +333,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context.setSelectionKey(mock(SelectionKey.class)); context.closeChannel(); ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); verify(selector).queueWrite(captor.capture()); @@ -345,18 +348,7 @@ public class SSLChannelContextTests extends ESTestCase { } } - public void testInitiateCloseFromDifferentThreadSchedulesCloseNotify() throws SSLException { - when(selector.isOnCurrentThread()).thenReturn(false, true); - context.closeChannel(); - - ArgumentCaptor captor = ArgumentCaptor.forClass(FlushReadyWrite.class); - verify(selector).queueWrite(captor.capture()); - - context.queueWriteOperation(captor.getValue()); - verify(sslDriver).initiateClose(); - } - - public void testInitiateCloseFromSameThreadSchedulesCloseNotify() throws SSLException { + public void testInitiateCloseSchedulesCloseNotify() throws SSLException { context.closeChannel(); ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); @@ -366,8 +358,15 @@ public class SSLChannelContextTests extends ESTestCase { verify(sslDriver).initiateClose(); } + public void testInitiateUnregisteredScheduledDirectClose() throws SSLException { + context.setSelectionKey(null); + context.closeChannel(); + + verify(selector).queueChannelClose(channel); + } + @SuppressWarnings("unchecked") - public void testRegisterInitiatesDriver() throws IOException { + public void testActiveInitiatesDriver() throws IOException { try (Selector realSelector = Selector.open(); SocketChannel realSocket = SocketChannel.open()) { realSocket.configureBlocking(false); @@ -375,7 +374,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realSocket); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); - context.register(); + context.channelActive(); verify(sslDriver).init(); } }