diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java index 48d83d21692..8e590c830b9 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -35,7 +35,7 @@ public abstract class BytesWriteHandler implements NioChannelHandler { } @Override - public void channelRegistered() {} + public void channelActive() {} @Override public List writeToBytes(WriteOperation writeOperation) { diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java index a2663385daa..a030f68fe8b 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelContext.java @@ -50,17 +50,19 @@ public abstract class ChannelContext context) throws IOException { context.register(); - SelectionKey selectionKey = context.getSelectionKey(); - selectionKey.attach(context); + assert context.getSelectionKey() != null : "SelectionKey should not be null after registration"; + assert context.getSelectionKey().attachment() != null : "Attachment should not be null after registration"; + } + + /** + * This method is called when an attempt to register a channel throws an exception. + * + * @param context that was registered + * @param exception that occurred + */ + protected void registrationException(ChannelContext context, Exception exception) { + context.handleException(exception); + } + + /** + * This method is called after a NioChannel is active with the selector. It should only be called once + * per channel. + * + * @param context that was marked active + */ + protected void handleActive(ChannelContext context) throws IOException { + context.channelActive(); if (context instanceof SocketChannelContext) { if (((SocketChannelContext) context).readyForFlush()) { SelectionKeyUtils.setConnectReadAndWriteInterested(context.getSelectionKey()); @@ -78,12 +98,12 @@ public class EventHandler { } /** - * This method is called when an attempt to register a channel throws an exception. + * This method is called when setting a channel to active throws an exception. * - * @param context that was registered + * @param context that was marked active * @param exception that occurred */ - protected void registrationException(ChannelContext context, Exception exception) { + protected void activeException(ChannelContext context, Exception exception) { context.handleException(exception); } @@ -180,15 +200,9 @@ public class EventHandler { closeException(context, e); } } else { - boolean pendingWrites = context.readyForFlush(); SelectionKey selectionKey = context.getSelectionKey(); - if (selectionKey == null) { - if (pendingWrites) { - writeException(context, new IllegalStateException("Tried to write to an not yet registered channel")); - } - return; - } boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey); + boolean pendingWrites = context.readyForFlush(); if (currentlyWriteInterested == false && pendingWrites) { SelectionKeyUtils.setWriteInterested(selectionKey); } else if (currentlyWriteInterested && pendingWrites == false) { diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java index 61bda9a4507..2d91e769368 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java @@ -29,9 +29,9 @@ import java.util.function.BiConsumer; public interface NioChannelHandler { /** - * This method is called when the channel is registered with its selector. + * This method is called when the channel is active for use. */ - void channelRegistered(); + void channelActive(); /** * This method is called when a message is queued with a channel. It can be called from any thread. diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java index 175c7661813..cbc069c8f36 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java @@ -340,21 +340,30 @@ public class NioSelector implements Closeable { private void writeToChannel(WriteOperation writeOperation) { assertOnSelectorThread(); SocketChannelContext context = writeOperation.getChannel(); - // If the channel does not currently have anything that is ready to flush, we should flush after - // the write operation is queued. - boolean shouldFlushAfterQueuing = context.readyForFlush() == false; - try { - context.queueWriteOperation(writeOperation); - } catch (Exception e) { - shouldFlushAfterQueuing = false; - executeFailedListener(writeOperation.getListener(), e); - } - if (shouldFlushAfterQueuing) { - if (context.selectorShouldClose() == false) { - handleWrite(context); + if (context.isOpen() == false) { + executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); + } else if (context.getSelectionKey() == null) { + // This should very rarely happen. The only times a channel is exposed outside the event loop, + // but might not registered is through the exception handler and channel accepted callbacks. + executeFailedListener(writeOperation.getListener(), new IllegalStateException("Channel not registered")); + } else { + // If the channel does not currently have anything that is ready to flush, we should flush after + // the write operation is queued. + boolean shouldFlushAfterQueuing = context.readyForFlush() == false; + try { + context.queueWriteOperation(writeOperation); + } catch (Exception e) { + shouldFlushAfterQueuing = false; + executeFailedListener(writeOperation.getListener(), e); + } + + if (shouldFlushAfterQueuing) { + if (context.selectorShouldClose() == false) { + handleWrite(context); + } + eventHandler.postHandling(context); } - eventHandler.postHandling(context); } } @@ -435,14 +444,25 @@ public class NioSelector implements Closeable { try { if (newChannel.isOpen()) { eventHandler.handleRegistration(newChannel); + channelActive(newChannel); if (newChannel instanceof SocketChannelContext) { attemptConnect((SocketChannelContext) newChannel, false); } } else { eventHandler.registrationException(newChannel, new ClosedChannelException()); + closeChannel(newChannel); } } catch (Exception e) { eventHandler.registrationException(newChannel, e); + closeChannel(newChannel); + } + } + + private void channelActive(ChannelContext newChannel) { + try { + eventHandler.handleActive(newChannel); + } catch (IOException e) { + eventHandler.activeException(newChannel, e); } } @@ -464,11 +484,7 @@ public class NioSelector implements Closeable { private void handleQueuedWrites() { WriteOperation writeOperation; while ((writeOperation = queuedWrites.poll()) != null) { - if (writeOperation.getChannel().isOpen()) { - writeToChannel(writeOperation); - } else { - executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); - } + writeToChannel(writeOperation); } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index f77ccb17aef..bc93466b58a 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -156,9 +156,8 @@ public abstract class SocketChannelContext extends ChannelContext } @Override - protected void register() throws IOException { - super.register(); - readWriteHandler.channelRegistered(); + protected void channelActive() throws IOException { + readWriteHandler.channelActive(); } @Override diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java index 578890b152f..726d87317ff 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java @@ -81,32 +81,25 @@ public class EventHandlerTests extends ESTestCase { } public void testRegisterCallsContext() throws IOException { - NioSocketChannel channel = mock(NioSocketChannel.class); - SocketChannelContext channelContext = mock(SocketChannelContext.class); - when(channel.getContext()).thenReturn(channelContext); - when(channelContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + ChannelContext channelContext = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class); + TestSelectionKey attachment = new TestSelectionKey(0); + when(channelContext.getSelectionKey()).thenReturn(attachment); + attachment.attach(channelContext); handler.handleRegistration(channelContext); verify(channelContext).register(); } - public void testRegisterNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException { + public void testActiveNonServerAddsOP_CONNECTAndOP_READInterest() throws IOException { SocketChannelContext context = mock(SocketChannelContext.class); when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(context); + handler.handleActive(context); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, context.getSelectionKey().interestOps()); } - public void testRegisterAddsAttachment() throws IOException { - ChannelContext context = randomBoolean() ? mock(SocketChannelContext.class) : mock(ServerChannelContext.class); - when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(context); - assertEquals(context, context.getSelectionKey().attachment()); - } - - public void testHandleServerRegisterSetsOP_ACCEPTInterest() throws IOException { - assertNull(serverContext.getSelectionKey()); - - handler.handleRegistration(serverContext); + public void testHandleServerActiveSetsOP_ACCEPTInterest() throws IOException { + ServerChannelContext serverContext = mock(ServerChannelContext.class); + when(serverContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleActive(serverContext); assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps()); } @@ -141,11 +134,11 @@ public class EventHandlerTests extends ESTestCase { verify(serverChannelContext).handleException(exception); } - public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { + public void testActiveWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { FlushReadyWrite flushReadyWrite = mock(FlushReadyWrite.class); when(readWriteHandler.writeToBytes(flushReadyWrite)).thenReturn(Collections.singletonList(flushReadyWrite)); context.queueWriteOperation(flushReadyWrite); - handler.handleRegistration(context); + handler.handleActive(context); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps()); } @@ -266,7 +259,9 @@ public class EventHandlerTests extends ESTestCase { @Override public void register() { - setSelectionKey(new TestSelectionKey(0)); + TestSelectionKey selectionKey = new TestSelectionKey(0); + setSelectionKey(selectionKey); + selectionKey.attach(this); } } @@ -280,7 +275,9 @@ public class EventHandlerTests extends ESTestCase { @Override public void register() { + TestSelectionKey selectionKey = new TestSelectionKey(0); setSelectionKey(new TestSelectionKey(0)); + selectionKey.attach(this); } } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java index 89a01d02ed4..f7bf4bc24be 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/NioSelectorTests.java @@ -212,6 +212,7 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); verify(eventHandler).handleRegistration(serverChannelContext); + verify(eventHandler).handleActive(serverChannelContext); } public void testClosedServerChannelWillNotBeRegistered() { @@ -230,7 +231,20 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); + verify(eventHandler, times(0)).handleActive(serverChannelContext); verify(eventHandler).registrationException(serverChannelContext, closedChannelException); + verify(eventHandler).handleClose(serverChannelContext); + } + + public void testChannelActiveException() throws Exception { + executeOnNewThread(() -> selector.scheduleForRegistration(serverChannel)); + IOException ioException = new IOException(); + doThrow(ioException).when(eventHandler).handleActive(serverChannelContext); + + selector.preSelect(); + + verify(eventHandler).handleActive(serverChannelContext); + verify(eventHandler).activeException(serverChannelContext, ioException); } public void testClosedSocketChannelWillNotBeRegistered() throws Exception { @@ -241,6 +255,7 @@ public class NioSelectorTests extends ESTestCase { verify(eventHandler).registrationException(same(channelContext), any(ClosedChannelException.class)); verify(eventHandler, times(0)).handleConnect(channelContext); + verify(eventHandler).handleClose(channelContext); } public void testRegisterSocketChannelFailsDueToException() throws InterruptedException { @@ -253,7 +268,9 @@ public class NioSelectorTests extends ESTestCase { selector.preSelect(); verify(eventHandler).registrationException(channelContext, closedChannelException); + verify(eventHandler, times(0)).handleActive(serverChannelContext); verify(eventHandler, times(0)).handleConnect(channelContext); + verify(eventHandler).handleClose(channelContext); }); } @@ -313,6 +330,17 @@ public class NioSelectorTests extends ESTestCase { verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); } + public void testQueueWriteChannelIsUnregistered() throws Exception { + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); + + executeOnNewThread(() -> selector.queueWrite(writeOperation)); + when(channelContext.getSelectionKey()).thenReturn(null); + selector.preSelect(); + + verify(channelContext, times(0)).queueWriteOperation(writeOperation); + verify(listener).accept(isNull(Void.class), any(IllegalStateException.class)); + } + public void testQueueWriteSuccessful() throws Exception { WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); executeOnNewThread(() -> selector.queueWrite(writeOperation)); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 5563ccc4306..210a27aa109 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -53,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase { private NioSocketChannel channel; private BiConsumer listener; private NioSelector selector; - private NioChannelHandler readWriteHandler; + private NioChannelHandler handler; private ByteBuffer ioBuffer = ByteBuffer.allocate(1024); @SuppressWarnings("unchecked") @@ -67,9 +67,9 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); selector = mock(NioSelector.class); - readWriteHandler = mock(NioChannelHandler.class); + handler = mock(NioChannelHandler.class); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { @@ -142,6 +142,11 @@ public class SocketChannelContextTests extends ESTestCase { assertSame(ioException, exception.get()); } + public void testChannelActiveCallsHandler() throws IOException { + context.channelActive(); + verify(handler).channelActive(); + } + public void testWriteFailsIfClosing() { context.closeChannel(); @@ -158,7 +163,7 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; WriteOperation writeOperation = mock(WriteOperation.class); - when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); context.sendMessage(buffers, listener); verify(selector).queueWrite(writeOpCaptor.capture()); @@ -172,7 +177,7 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; WriteOperation writeOperation = mock(WriteOperation.class); - when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + when(handler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); context.sendMessage(buffers, listener); verify(selector).queueWrite(writeOpCaptor.capture()); @@ -186,16 +191,16 @@ public class SocketChannelContextTests extends ESTestCase { ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; FlushReadyWrite writeOperation = new FlushReadyWrite(context, buffer, listener); - when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation)); + when(handler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation)); context.queueWriteOperation(writeOperation); - verify(readWriteHandler).writeToBytes(writeOperation); + verify(handler).writeToBytes(writeOperation); assertTrue(context.readyForFlush()); } public void testHandleReadBytesWillCheckForNewFlushOperations() throws IOException { assertFalse(context.readyForFlush()); - when(readWriteHandler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class))); + when(handler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class))); context.handleReadBytes(); assertTrue(context.readyForFlush()); } @@ -205,14 +210,14 @@ public class SocketChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); assertFalse(context.readyForFlush()); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; WriteOperation writeOperation = mock(WriteOperation.class); BiConsumer listener2 = mock(BiConsumer.class); - when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + when(handler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), new FlushOperation(buffer, listener2))); context.queueWriteOperation(writeOperation); @@ -233,7 +238,7 @@ public class SocketChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, channelBuffer); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; @@ -241,7 +246,7 @@ public class SocketChannelContextTests extends ESTestCase { assertFalse(context.readyForFlush()); when(channel.isOpen()).thenReturn(true); - when(readWriteHandler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + when(handler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), new FlushOperation(buffer, listener2))); context.closeFromSelector(); @@ -257,9 +262,9 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); - BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer); context.closeFromSelector(); - verify(readWriteHandler).close(); + verify(handler).close(); } } @@ -271,7 +276,7 @@ public class SocketChannelContextTests extends ESTestCase { IntFunction pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer); InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator); buffer.ensureCapacity(1); - TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer); context.closeFromSelector(); verify(closer).run(); } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index c603e20ffc9..9802b24a14a 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -58,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { private final TaskScheduler taskScheduler; private final LongSupplier nanoClock; private final long readTimeoutNanos; - private boolean channelRegistered = false; + private boolean channelActive = false; private boolean requestSinceReadTimeoutTrigger = false; private int inFlightRequests = 0; @@ -91,8 +91,8 @@ public class HttpReadWriteHandler implements NioChannelHandler { } @Override - public void channelRegistered() { - channelRegistered = true; + public void channelActive() { + channelActive = true; if (readTimeoutNanos > 0) { scheduleReadTimeout(); } @@ -100,7 +100,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { @Override public int consumeReads(InboundChannelBuffer channelBuffer) { - assert channelRegistered : "channelRegistered should have been called"; + assert channelActive : "channelActive should have been called"; int bytesConsumed = adaptor.read(channelBuffer.sliceAndRetainPagesTo(channelBuffer.getIndex())); Object message; while ((message = adaptor.pollInboundMessage()) != null) { @@ -123,7 +123,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { public List writeToBytes(WriteOperation writeOperation) { assert writeOperation.getObject() instanceof NioHttpResponse : "This channel only supports messages that are of type: " + NioHttpResponse.class + ". Found type: " + writeOperation.getObject().getClass() + "."; - assert channelRegistered : "channelRegistered should have been called"; + assert channelActive : "channelActive should have been called"; --inFlightRequests; assert inFlightRequests >= 0 : "Inflight requests should never drop below zero, found: " + inFlightRequests; adaptor.write(writeOperation); diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 93a846ea90f..253487179b2 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -100,7 +100,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { NioCorsConfig corsConfig = NioCorsConfigBuilder.forAnyOrigin().build(); handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime); - handler.channelRegistered(); + handler.channelActive(); } public void testSuccessfulDecodeHttpRequest() throws IOException { @@ -334,7 +334,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { Iterator timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator(); handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next); - handler.channelRegistered(); + handler.channelActive(); prepareHandlerForResponse(handler); SocketChannelContext context = mock(SocketChannelContext.class); @@ -381,7 +381,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler, System::nanoTime); - handler.channelRegistered(); + handler.channelActive(); prepareHandlerForResponse(handler); DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); if (originValue != null) { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java index ed55007f3ba..08fcf8bb44a 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java @@ -227,7 +227,7 @@ class NioHttpClient implements Closeable { } @Override - public void channelRegistered() {} + public void channelActive() {} @Override public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java index 069e19c3455..4a1c6f5deb6 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestEventHandler.java @@ -92,6 +92,30 @@ public class TestEventHandler extends EventHandler { } } + @Override + protected void handleActive(ChannelContext context) throws IOException { + final boolean registered = transportThreadWatchdog.register(); + try { + super.handleActive(context); + } finally { + if (registered) { + transportThreadWatchdog.unregister(); + } + } + } + + @Override + protected void activeException(ChannelContext context, Exception exception) { + final boolean registered = transportThreadWatchdog.register(); + try { + super.activeException(context, exception); + } finally { + if (registered) { + transportThreadWatchdog.unregister(); + } + } + } + public void handleConnect(SocketChannelContext context) throws IOException { assert hasConnectedMap.contains(context) == false : "handleConnect should only be called is a channel is not yet connected"; final boolean registered = transportThreadWatchdog.register(); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java index 12f6b67d672..b90a6f4991c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java @@ -28,9 +28,9 @@ public final class NioIPFilter extends DelegatingHandler { } @Override - public void channelRegistered() { + public void channelActive() { if (filter.accept(profile, remoteAddress)) { - super.channelRegistered(); + super.channelActive(); } else { denied = true; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 8947447ef58..6a1684dd024 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -55,8 +55,8 @@ public final class SSLChannelContext extends SocketChannelContext { } @Override - public void register() throws IOException { - super.register(); + protected void channelActive() throws IOException { + super.channelActive(); sslDriver.init(); SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush()) { @@ -179,8 +179,15 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - WriteOperation writeOperation = new CloseNotifyOperation(this); - getSelector().queueWrite(writeOperation); + // The model for closing channels will change at some point, removing the need for this "schedule + // a write" signal. But for now, we need to handle the edge case where the channel is not + // registered. + if (getSelectionKey() == null) { + getSelector().queueChannelClose(channel); + } else { + WriteOperation writeOperation = new CloseNotifyOperation(this); + getSelector().queueWrite(writeOperation); + } } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java index e7612c0c0d7..842c9f031ef 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java @@ -82,8 +82,8 @@ public class NioIPFilterTests extends ESTestCase { InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345); NioChannelHandler delegate = mock(NioChannelHandler.class); NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); - nioIPFilter.channelRegistered(); - verify(delegate).channelRegistered(); + nioIPFilter.channelActive(); + verify(delegate).channelActive(); assertFalse(nioIPFilter.closeNow()); } @@ -91,8 +91,8 @@ public class NioIPFilterTests extends ESTestCase { InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345); NioChannelHandler delegate = mock(NioChannelHandler.class); NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); - nioIPFilter.channelRegistered(); - verify(delegate, times(0)).channelRegistered(); + nioIPFilter.channelActive(); + verify(delegate, times(0)).channelActive(); assertTrue(nioIPFilter.closeNow()); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 7efff1c0e26..8e0a5ad23af 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -24,6 +24,7 @@ import org.mockito.stubbing.Answer; import javax.net.ssl.SSLException; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; @@ -73,6 +74,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context.setSelectionKey(mock(SelectionKey.class)); when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); @@ -331,6 +333,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context.setSelectionKey(mock(SelectionKey.class)); context.closeChannel(); ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); verify(selector).queueWrite(captor.capture()); @@ -345,18 +348,7 @@ public class SSLChannelContextTests extends ESTestCase { } } - public void testInitiateCloseFromDifferentThreadSchedulesCloseNotify() throws SSLException { - when(selector.isOnCurrentThread()).thenReturn(false, true); - context.closeChannel(); - - ArgumentCaptor captor = ArgumentCaptor.forClass(FlushReadyWrite.class); - verify(selector).queueWrite(captor.capture()); - - context.queueWriteOperation(captor.getValue()); - verify(sslDriver).initiateClose(); - } - - public void testInitiateCloseFromSameThreadSchedulesCloseNotify() throws SSLException { + public void testInitiateCloseSchedulesCloseNotify() throws SSLException { context.closeChannel(); ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); @@ -366,8 +358,15 @@ public class SSLChannelContextTests extends ESTestCase { verify(sslDriver).initiateClose(); } + public void testInitiateUnregisteredScheduledDirectClose() throws SSLException { + context.setSelectionKey(null); + context.closeChannel(); + + verify(selector).queueChannelClose(channel); + } + @SuppressWarnings("unchecked") - public void testRegisterInitiatesDriver() throws IOException { + public void testActiveInitiatesDriver() throws IOException { try (Selector realSelector = Selector.open(); SocketChannel realSocket = SocketChannel.open()) { realSocket.configureBlocking(false); @@ -375,7 +374,7 @@ public class SSLChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realSocket); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); - context.register(); + context.channelActive(); verify(sslDriver).init(); } }