From 5a8ec9b762c68336b68f81a2ec32909030faf090 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 22 Feb 2018 09:44:52 -0700 Subject: [PATCH] Selectors operate on channel contexts (#28468) This commit is related to #27260. Currently there is a weird relationship between channel contexts and nio channels. The selectors use the context for read and writing. But the selector operates directly on the nio channel for registering, closing, and connecting. This commit works on improving this relationship. The selector operates directly on the context which wraps the low level java.nio.channels. The NioChannel class is simply an API that is used to interact with the channel (sending messages from outside the selector event loop, scheduling a close, adding listeners, etc). The context is only used internally by the channel to implement these apis and by the selector to perform these operations. --- .../elasticsearch/nio/AbstractNioChannel.java | 131 ----------- .../elasticsearch/nio/AcceptingSelector.java | 25 +- .../nio/AcceptorEventHandler.java | 35 +-- .../nio/BytesChannelContext.java | 27 +-- .../nio/BytesWriteOperation.java | 10 +- .../org/elasticsearch/nio/ChannelContext.java | 74 +++++- .../org/elasticsearch/nio/ChannelFactory.java | 34 ++- .../org/elasticsearch/nio/ESSelector.java | 23 +- .../org/elasticsearch/nio/EventHandler.java | 20 +- .../org/elasticsearch/nio/NioChannel.java | 49 ++-- .../java/org/elasticsearch/nio/NioGroup.java | 4 +- .../nio/NioServerSocketChannel.java | 22 +- .../elasticsearch/nio/NioSocketChannel.java | 102 +-------- .../elasticsearch/nio/SelectionKeyUtils.java | 51 ++--- .../nio/ServerChannelContext.java | 41 ++-- .../nio/SocketChannelContext.java | 79 ++++++- .../elasticsearch/nio/SocketEventHandler.java | 88 +++---- .../org/elasticsearch/nio/SocketSelector.java | 70 +++--- .../org/elasticsearch/nio/WriteOperation.java | 2 +- .../nio/AcceptingSelectorTests.java | 35 +-- .../nio/AcceptorEventHandlerTests.java | 73 ++++-- .../nio/BytesChannelContextTests.java | 82 ++++--- ...sts.java => BytesWriteOperationTests.java} | 15 +- .../nio/ChannelContextTests.java | 214 ++++++++++++++++++ .../nio/ChannelFactoryTests.java | 37 +-- .../nio/DoNotRegisterChannel.java | 36 --- .../nio/DoNotRegisterServerChannel.java | 37 --- .../elasticsearch/nio/ESSelectorTests.java | 8 +- .../nio/NioServerSocketChannelTests.java | 99 -------- .../nio/NioSocketChannelTests.java | 133 ----------- .../nio/SocketChannelContextTests.java | 173 ++++++++++++++ .../nio/SocketEventHandlerTests.java | 142 +++++++----- .../nio/SocketSelectorTests.java | 117 ++++------ .../transport/nio/NioTransport.java | 17 +- .../nio/TcpNioServerSocketChannel.java | 8 +- .../transport/nio/TcpNioSocketChannel.java | 4 +- .../transport/nio/MockNioTransport.java | 13 +- .../nio/TestingSocketEventHandler.java | 41 ++-- 38 files changed, 1123 insertions(+), 1048 deletions(-) delete mode 100644 libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java rename libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/{WriteOperationTests.java => BytesWriteOperationTests.java} (91%) create mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java delete mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java delete mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java delete mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java delete mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java create mode 100644 libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java deleted file mode 100644 index 14e2365eb7e..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.NetworkChannel; -import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; -import java.util.concurrent.CompletableFuture; -import java.util.function.BiConsumer; - -/** - * This is a basic channel abstraction used by the {@link ESSelector}. - *

- * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return - * true until the channel is explicitly closed. - *

- * A channel lifecycle has two stages: - *

    - *
  1. OPEN - When a channel has been created. This is the state of a channel that can perform normal operations. - *
  2. CLOSED - The channel has been set to closed. All this means is that the channel has been scheduled to be - * closed. The underlying raw channel may not yet be closed. The underlying channel has been closed if the close - * future has been completed. - *
- * - * @param the type of raw channel this AbstractNioChannel uses - */ -public abstract class AbstractNioChannel implements NioChannel { - - final S socketChannel; - - private final InetSocketAddress localAddress; - private final CompletableFuture closeContext = new CompletableFuture<>(); - private final ESSelector selector; - private SelectionKey selectionKey; - - AbstractNioChannel(S socketChannel, ESSelector selector) throws IOException { - this.socketChannel = socketChannel; - this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); - this.selector = selector; - } - - @Override - public boolean isOpen() { - return closeContext.isDone() == false; - } - - @Override - public InetSocketAddress getLocalAddress() { - return localAddress; - } - - /** - * Closes the channel synchronously. This method should only be called from the selector thread. - *

- * Once this method returns, the channel will be closed. - */ - @Override - public void closeFromSelector() throws IOException { - selector.assertOnSelectorThread(); - if (closeContext.isDone() == false) { - try { - socketChannel.close(); - closeContext.complete(null); - } catch (IOException e) { - closeContext.completeExceptionally(e); - throw e; - } - } - } - - /** - * This method attempts to registered a channel with the raw nio selector. It also sets the selection - * key. - * - * @throws ClosedChannelException if the raw channel was closed - */ - @Override - public void register() throws ClosedChannelException { - setSelectionKey(socketChannel.register(selector.rawSelector(), 0)); - } - - @Override - public ESSelector getSelector() { - return selector; - } - - @Override - public SelectionKey getSelectionKey() { - return selectionKey; - } - - @Override - public S getRawChannel() { - return socketChannel; - } - - @Override - public void addCloseListener(BiConsumer listener) { - closeContext.whenComplete(listener); - } - - @Override - public void close() { - getContext().closeChannel(); - } - - // Package visibility for testing - void setSelectionKey(SelectionKey selectionKey) { - this.selectionKey = selectionKey; - } -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java index 2cbf7657e5d..da64020daa8 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java @@ -24,6 +24,7 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.stream.Collectors; /** * Selector implementation that handles {@link NioServerSocketChannel}. It's main piece of functionality is @@ -46,12 +47,12 @@ public class AcceptingSelector extends ESSelector { @Override void processKey(SelectionKey selectionKey) { - NioServerSocketChannel serverChannel = (NioServerSocketChannel) selectionKey.attachment(); + ServerChannelContext channelContext = (ServerChannelContext) selectionKey.attachment(); if (selectionKey.isAcceptable()) { try { - eventHandler.acceptChannel(serverChannel); + eventHandler.acceptChannel(channelContext); } catch (IOException e) { - eventHandler.acceptException(serverChannel, e); + eventHandler.acceptException(channelContext, e); } } } @@ -63,7 +64,7 @@ public class AcceptingSelector extends ESSelector { @Override void cleanup() { - channelsToClose.addAll(newChannels); + channelsToClose.addAll(newChannels.stream().map(NioServerSocketChannel::getContext).collect(Collectors.toList())); } /** @@ -81,18 +82,16 @@ public class AcceptingSelector extends ESSelector { private void setUpNewServerChannels() { NioServerSocketChannel newChannel; while ((newChannel = this.newChannels.poll()) != null) { - assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; + ServerChannelContext context = newChannel.getContext(); + assert context.getSelector() == this : "The channel must be registered with the selector with which it was created"; try { - if (newChannel.isOpen()) { - newChannel.register(); - SelectionKey selectionKey = newChannel.getSelectionKey(); - selectionKey.attach(newChannel); - eventHandler.serverChannelRegistered(newChannel); + if (context.isOpen()) { + eventHandler.handleRegistration(context); } else { - eventHandler.registrationException(newChannel, new ClosedChannelException()); + eventHandler.registrationException(context, new ClosedChannelException()); } - } catch (IOException e) { - eventHandler.registrationException(newChannel, e); + } catch (Exception e) { + eventHandler.registrationException(context, e); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java index eb5194f21ef..474efad3c77 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import java.io.IOException; +import java.nio.channels.SelectionKey; import java.util.function.Supplier; /** @@ -38,46 +39,46 @@ public class AcceptorEventHandler extends EventHandler { } /** - * This method is called when a NioServerSocketChannel is successfully registered. It should only be - * called once per channel. + * This method is called when a NioServerSocketChannel is being registered with the selector. It should + * only be called once per channel. * - * @param nioServerSocketChannel that was registered + * @param context that was registered */ - protected void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) { - SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel); + protected void handleRegistration(ServerChannelContext context) throws IOException { + context.register(); + SelectionKey selectionKey = context.getSelectionKey(); + selectionKey.attach(context); + SelectionKeyUtils.setAcceptInterested(selectionKey); } /** * This method is called when an attempt to register a server channel throws an exception. * - * @param channel that was registered + * @param context that was registered * @param exception that occurred */ - protected void registrationException(NioServerSocketChannel channel, Exception exception) { - logger.error(new ParameterizedMessage("failed to register server channel: {}", channel), exception); + protected void registrationException(ServerChannelContext context, Exception exception) { + logger.error(new ParameterizedMessage("failed to register server channel: {}", context.getChannel()), exception); } /** * This method is called when a server channel signals it is ready to accept a connection. All of the * accept logic should occur in this call. * - * @param nioServerChannel that can accept a connection + * @param context that can accept a connection */ - protected void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException { - ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); - SocketSelector selector = selectorSupplier.get(); - NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector); - nioServerChannel.getContext().acceptChannel(nioSocketChannel); + protected void acceptChannel(ServerChannelContext context) throws IOException { + context.acceptChannels(selectorSupplier); } /** * This method is called when an attempt to accept a connection throws an exception. * - * @param nioServerChannel that accepting a connection + * @param context that accepting a connection * @param exception that occurred */ - protected void acceptException(NioServerSocketChannel nioServerChannel, Exception exception) { + protected void acceptException(ServerChannelContext context, Exception exception) { logger.debug(() -> new ParameterizedMessage("exception while accepting new channel from server channel: {}", - nioServerChannel), exception); + context.getChannel()), exception); } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java index 5d77675aa48..000e871e927 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java @@ -25,6 +25,7 @@ import java.nio.channels.ClosedChannelException; import java.util.LinkedList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; +import java.util.function.Consumer; public class BytesChannelContext extends SocketChannelContext { @@ -33,9 +34,9 @@ public class BytesChannelContext extends SocketChannelContext { private final LinkedList queued = new LinkedList<>(); private final AtomicBoolean isClosing = new AtomicBoolean(false); - public BytesChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler, + public BytesChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) { - super(channel, exceptionHandler); + super(channel, selector, exceptionHandler); this.readConsumer = readConsumer; this.channelBuffer = channelBuffer; } @@ -71,8 +72,8 @@ public class BytesChannelContext extends SocketChannelContext { return; } - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); - SocketSelector selector = channel.getSelector(); + BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener); + SocketSelector selector = getSelector(); if (selector.isOnCurrentThread() == false) { selector.queueWrite(writeOperation); return; @@ -83,13 +84,13 @@ public class BytesChannelContext extends SocketChannelContext { @Override public void queueWriteOperation(WriteOperation writeOperation) { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); queued.add((BytesWriteOperation) writeOperation); } @Override public void flushChannel() throws IOException { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); int ops = queued.size(); if (ops == 1) { singleFlush(queued.pop()); @@ -100,14 +101,14 @@ public class BytesChannelContext extends SocketChannelContext { @Override public boolean hasQueuedWriteOps() { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); return queued.isEmpty() == false; } @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - channel.getSelector().queueChannelClose(channel); + getSelector().queueChannelClose(channel); } } @@ -118,11 +119,11 @@ public class BytesChannelContext extends SocketChannelContext { @Override public void closeFromSelector() throws IOException { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); if (channel.isOpen()) { IOException channelCloseException = null; try { - channel.closeFromSelector(); + super.closeFromSelector(); } catch (IOException e) { channelCloseException = e; } @@ -130,7 +131,7 @@ public class BytesChannelContext extends SocketChannelContext { isClosing.set(true); channelBuffer.close(); for (BytesWriteOperation op : queued) { - channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); } queued.clear(); if (channelCloseException != null) { @@ -144,12 +145,12 @@ public class BytesChannelContext extends SocketChannelContext { int written = flushToChannel(headOp.getBuffersToWrite()); headOp.incrementIndex(written); } catch (IOException e) { - channel.getSelector().executeFailedListener(headOp.getListener(), e); + getSelector().executeFailedListener(headOp.getListener(), e); throw e; } if (headOp.isFullyFlushed()) { - channel.getSelector().executeListener(headOp.getListener(), null); + getSelector().executeListener(headOp.getListener(), null); } else { queued.push(headOp); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java index 14e8cace66d..37c6e497276 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java @@ -25,15 +25,15 @@ import java.util.function.BiConsumer; public class BytesWriteOperation implements WriteOperation { - private final NioSocketChannel channel; + private final SocketChannelContext channelContext; private final BiConsumer listener; private final ByteBuffer[] buffers; private final int[] offsets; private final int length; private int internalIndex; - public BytesWriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer listener) { - this.channel = channel; + public BytesWriteOperation(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { + this.channelContext = channelContext; this.listener = listener; this.buffers = buffers; this.offsets = new int[buffers.length]; @@ -52,8 +52,8 @@ public class BytesWriteOperation implements WriteOperation { } @Override - public NioSocketChannel getChannel() { - return channel; + public SocketChannelContext getChannel() { + return channelContext; } public boolean isFullyFlushed() { diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java index fa664484c1c..01f35347aa4 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java @@ -20,15 +20,78 @@ package org.elasticsearch.nio; import java.io.IOException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** + * Implements the logic related to interacting with a java.nio channel. For example: registering with a + * selector, managing the selection key, closing, etc is implemented by this class or its subclasses. + * + * @param the type of channel + */ +public abstract class ChannelContext { + + protected final S rawChannel; + private final Consumer exceptionHandler; + private final CompletableFuture closeContext = new CompletableFuture<>(); + private volatile SelectionKey selectionKey; + + ChannelContext(S rawChannel, Consumer exceptionHandler) { + this.rawChannel = rawChannel; + this.exceptionHandler = exceptionHandler; + } + + protected void register() throws IOException { + setSelectionKey(rawChannel.register(getSelector().rawSelector(), 0)); + } + + SelectionKey getSelectionKey() { + return selectionKey; + } + + // Protected for tests + protected void setSelectionKey(SelectionKey selectionKey) { + this.selectionKey = selectionKey; + } -public interface ChannelContext { /** * This method cleans up any context resources that need to be released when a channel is closed. It * should only be called by the selector thread. * * @throws IOException during channel / context close */ - void closeFromSelector() throws IOException; + public void closeFromSelector() throws IOException { + if (closeContext.isDone() == false) { + try { + rawChannel.close(); + closeContext.complete(null); + } catch (Exception e) { + closeContext.completeExceptionally(e); + throw e; + } + } + } + + /** + * Add a listener that will be called when the channel is closed. + * + * @param listener to be called + */ + public void addCloseListener(BiConsumer listener) { + closeContext.whenComplete(listener); + } + + public boolean isOpen() { + return closeContext.isDone() == false; + } + + void handleException(Exception e) { + exceptionHandler.accept(e); + } /** * Schedules a channel to be closed by the selector event loop with which it is registered. @@ -39,7 +102,10 @@ public interface ChannelContext { * Depending on the underlying protocol of the channel, a close operation might simply close the socket * channel or may involve reading and writing messages. */ - void closeChannel(); + public abstract void closeChannel(); + + public abstract ESSelector getSelector(); + + public abstract NioChannel getChannel(); - void handleException(Exception e); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java index 5fc3f46f998..a03c4bcc15b 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java @@ -27,6 +27,7 @@ import java.nio.channels.SocketChannel; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.function.Supplier; public abstract class ChannelFactory { @@ -41,22 +42,30 @@ public abstract class ChannelFactory supplier) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); + SocketSelector selector = supplier.get(); Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public Socket acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { - SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); - Socket channel = internalCreateChannel(selector, rawChannel); - scheduleChannel(channel, selector); - return channel; + public Socket acceptNioChannel(ServerChannelContext serverContext, Supplier supplier) throws IOException { + SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverContext); + // Null is returned if there are no pending sockets to accept + if (rawChannel == null) { + return null; + } else { + SocketSelector selector = supplier.get(); + Socket channel = internalCreateChannel(selector, rawChannel); + scheduleChannel(channel, selector); + return channel; + } } - public ServerSocket openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) throws IOException { + public ServerSocket openNioServerSocketChannel(InetSocketAddress address, Supplier supplier) throws IOException { ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); + AcceptingSelector selector = supplier.get(); ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); scheduleServerChannel(serverChannel, selector); return serverChannel; @@ -140,7 +149,7 @@ public abstract class ChannelFactory channelsToClose = new ConcurrentLinkedQueue<>(); + final ConcurrentLinkedQueue> channelsToClose = new ConcurrentLinkedQueue<>(); private final EventHandler eventHandler; private final ReentrantLock runLock = new ReentrantLock(); @@ -60,7 +60,7 @@ public abstract class ESSelector implements Closeable { this(eventHandler, Selector.open()); } - ESSelector(EventHandler eventHandler, Selector selector) throws IOException { + ESSelector(EventHandler eventHandler, Selector selector) { this.eventHandler = eventHandler; this.selector = selector; } @@ -111,10 +111,10 @@ public abstract class ESSelector implements Closeable { try { processKey(sk); } catch (CancelledKeyException cke) { - eventHandler.genericChannelException((NioChannel) sk.attachment(), cke); + eventHandler.genericChannelException((ChannelContext) sk.attachment(), cke); } } else { - eventHandler.genericChannelException((NioChannel) sk.attachment(), new CancelledKeyException()); + eventHandler.genericChannelException((ChannelContext) sk.attachment(), new CancelledKeyException()); } } } @@ -131,7 +131,7 @@ public abstract class ESSelector implements Closeable { void cleanupAndCloseChannels() { cleanup(); - channelsToClose.addAll(selector.keys().stream().map(sk -> (NioChannel) sk.attachment()).collect(Collectors.toList())); + channelsToClose.addAll(selector.keys().stream().map(sk -> (ChannelContext) sk.attachment()).collect(Collectors.toList())); closePendingChannels(); } @@ -191,9 +191,10 @@ public abstract class ESSelector implements Closeable { } public void queueChannelClose(NioChannel channel) { - assert channel.getSelector() == this : "Must schedule a channel for closure with its selector"; - channelsToClose.offer(channel); - ensureSelectorOpenForEnqueuing(channelsToClose, channel); + ChannelContext context = channel.getContext(); + assert context.getSelector() == this : "Must schedule a channel for closure with its selector"; + channelsToClose.offer(context); + ensureSelectorOpenForEnqueuing(channelsToClose, context); wakeup(); } @@ -239,9 +240,9 @@ public abstract class ESSelector implements Closeable { } private void closePendingChannels() { - NioChannel channel; - while ((channel = channelsToClose.poll()) != null) { - eventHandler.handleClose(channel); + ChannelContext channelContext; + while ((channelContext = channelsToClose.poll()) != null) { + eventHandler.handleClose(channelContext); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java index 7cba9b998b3..d35b73c56b8 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java @@ -65,25 +65,25 @@ public abstract class EventHandler { /** * This method handles the closing of an NioChannel * - * @param channel that should be closed + * @param context that should be closed */ - protected void handleClose(NioChannel channel) { + protected void handleClose(ChannelContext context) { try { - channel.getContext().closeFromSelector(); + context.closeFromSelector(); } catch (IOException e) { - closeException(channel, e); + closeException(context, e); } - assert channel.isOpen() == false : "Should always be done as we are on the selector thread"; + assert context.isOpen() == false : "Should always be done as we are on the selector thread"; } /** * This method is called when an attempt to close a channel throws an exception. * - * @param channel that was being closed + * @param context that was being closed * @param exception that occurred */ - protected void closeException(NioChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while closing channel: {}", channel), exception); + protected void closeException(ChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while closing channel: {}", context.getChannel()), exception); } /** @@ -94,7 +94,7 @@ public abstract class EventHandler { * @param channel that caused the exception * @param exception that was thrown */ - protected void genericChannelException(NioChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while handling event for channel: {}", channel), exception); + protected void genericChannelException(ChannelContext channel, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while handling event for channel: {}", channel.getChannel()), exception); } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java index 690e3d3b38b..2f9705f5f8f 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java @@ -21,30 +21,30 @@ package org.elasticsearch.nio; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; -import java.nio.channels.SelectionKey; import java.util.function.BiConsumer; -public interface NioChannel { +/** + * This is a basic channel abstraction used by the {@link ESSelector}. + *

+ * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return + * true until the channel is explicitly closed. + */ +public abstract class NioChannel { - boolean isOpen(); + private final InetSocketAddress localAddress; - InetSocketAddress getLocalAddress(); + NioChannel(NetworkChannel socketChannel) throws IOException { + this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); + } - void close(); + public boolean isOpen() { + return getContext().isOpen(); + } - void closeFromSelector() throws IOException; - - void register() throws ClosedChannelException; - - ESSelector getSelector(); - - SelectionKey getSelectionKey(); - - NetworkChannel getRawChannel(); - - ChannelContext getContext(); + public InetSocketAddress getLocalAddress() { + return localAddress; + } /** * Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee @@ -53,5 +53,18 @@ public interface NioChannel { * * @param listener to be called at close */ - void addCloseListener(BiConsumer listener); + public void addCloseListener(BiConsumer listener) { + getContext().addCloseListener(listener); + } + + /** + * Schedules channel for close. This process is asynchronous. + */ + public void close() { + getContext().closeChannel(); + } + + public abstract NetworkChannel getRawChannel(); + + public abstract ChannelContext getContext(); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java index 109d22c45fd..b7637656162 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java @@ -96,12 +96,12 @@ public class NioGroup implements AutoCloseable { if (acceptors.isEmpty()) { throw new IllegalArgumentException("There are no acceptors configured. Without acceptors, server channels are not supported."); } - return factory.openNioServerSocketChannel(address, acceptorSupplier.get()); + return factory.openNioServerSocketChannel(address, acceptorSupplier); } public S openChannel(InetSocketAddress address, ChannelFactory factory) throws IOException { ensureOpen(); - return factory.openNioChannel(address, socketSelectorSupplier.get()); + return factory.openNioChannel(address, socketSelectorSupplier); } @Override diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java index 3d1748e413a..9f78c3b1b31 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java @@ -23,20 +23,15 @@ import java.io.IOException; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.atomic.AtomicBoolean; -public class NioServerSocketChannel extends AbstractNioChannel { +public class NioServerSocketChannel extends NioChannel { - private final ChannelFactory channelFactory; - private ServerChannelContext context; + private final ServerSocketChannel socketChannel; private final AtomicBoolean contextSet = new AtomicBoolean(false); + private ServerChannelContext context; - public NioServerSocketChannel(ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector) - throws IOException { - super(socketChannel, selector); - this.channelFactory = channelFactory; - } - - public ChannelFactory getChannelFactory() { - return channelFactory; + public NioServerSocketChannel(ServerSocketChannel socketChannel) throws IOException { + super(socketChannel); + this.socketChannel = socketChannel; } /** @@ -53,6 +48,11 @@ public class NioServerSocketChannel extends AbstractNioChannel { +public class NioSocketChannel extends NioChannel { private final InetSocketAddress remoteAddress; - private final CompletableFuture connectContext = new CompletableFuture<>(); - private final SocketSelector socketSelector; private final AtomicBoolean contextSet = new AtomicBoolean(false); + private final SocketChannel socketChannel; private SocketChannelContext context; - private Exception connectException; - public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + public NioSocketChannel(SocketChannel socketChannel) throws IOException { + super(socketChannel); + this.socketChannel = socketChannel; this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress(); - this.socketSelector = selector; - } - - @Override - public SocketSelector getSelector() { - return socketSelector; - } - - public int write(ByteBuffer buffer) throws IOException { - return socketChannel.write(buffer); - } - - public int write(ByteBuffer[] buffers) throws IOException { - if (buffers.length == 1) { - return socketChannel.write(buffers[0]); - } else { - return (int) socketChannel.write(buffers); - } - } - - public int read(ByteBuffer buffer) throws IOException { - return socketChannel.read(buffer); - } - - public int read(ByteBuffer[] buffers) throws IOException { - if (buffers.length == 1) { - return socketChannel.read(buffers[0]); - } else { - return (int) socketChannel.read(buffers); - } } public void setContext(SocketChannelContext context) { @@ -79,6 +46,11 @@ public class NioSocketChannel extends AbstractNioChannel { } } + @Override + public SocketChannel getRawChannel() { + return socketChannel; + } + @Override public SocketChannelContext getContext() { return context; @@ -88,46 +60,8 @@ public class NioSocketChannel extends AbstractNioChannel { return remoteAddress; } - public boolean isConnectComplete() { - return isConnectComplete0(); - } - - /** - * This method will attempt to complete the connection process for this channel. It should be called for - * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then - * the connection is complete and the channel is ready for reads and writes. If it returns false, the - * channel is not yet connected and this method should be called again when a OP_CONNECT event is - * received. - * - * @return true if the connection process is complete - * @throws IOException if an I/O error occurs - */ - public boolean finishConnect() throws IOException { - if (isConnectComplete0()) { - return true; - } else if (connectContext.isCompletedExceptionally()) { - Exception exception = connectException; - if (exception == null) { - throw new AssertionError("Should have received connection exception"); - } else if (exception instanceof IOException) { - throw (IOException) exception; - } else { - throw (RuntimeException) exception; - } - } - - boolean isConnected = socketChannel.isConnected(); - if (isConnected == false) { - isConnected = internalFinish(); - } - if (isConnected) { - connectContext.complete(null); - } - return isConnected; - } - public void addConnectListener(BiConsumer listener) { - connectContext.whenComplete(listener); + context.addConnectListener(listener); } @Override @@ -137,18 +71,4 @@ public class NioSocketChannel extends AbstractNioChannel { ", remoteAddress=" + remoteAddress + '}'; } - - private boolean internalFinish() throws IOException { - try { - return socketChannel.finishConnect(); - } catch (IOException | RuntimeException e) { - connectException = e; - connectContext.completeExceptionally(e); - throw e; - } - } - - private boolean isConnectComplete0() { - return connectContext.isDone() && connectContext.isCompletedExceptionally() == false; - } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java index be2dc6f3414..93d58344de6 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java @@ -27,80 +27,75 @@ public final class SelectionKeyUtils { private SelectionKeyUtils() {} /** - * Adds an interest in writes for this channel while maintaining other interests. + * Adds an interest in writes for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE); } /** - * Removes an interest in writes for this channel while maintaining other interests. + * Removes an interest in writes for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void removeWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE); } /** - * Removes an interest in connects and reads for this channel while maintaining other interests. + * Removes an interest in connects and reads for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setConnectAndReadInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ); } /** - * Removes an interest in connects, reads, and writes for this channel while maintaining other interests. + * Removes an interest in connects, reads, and writes for this selection key while maintaining other + * interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setConnectReadAndWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setConnectReadAndWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ | SelectionKey.OP_WRITE); } /** - * Removes an interest in connects for this channel while maintaining other interests. + * Removes an interest in connects for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void removeConnectInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT); } /** - * Adds an interest in accepts for this channel while maintaining other interests. + * Adds an interest in accepts for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setAcceptInterested(NioServerSocketChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setAcceptInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT); } /** - * Checks for an interest in writes for this channel. + * Checks for an interest in writes for this selection key. * - * @param channel the channel + * @param selectionKey the selection key * @return a boolean indicating if we are currently interested in writes for this channel * @throws CancelledKeyException if the key was already cancelled */ - public static boolean isWriteInterested(NioSocketChannel channel) throws CancelledKeyException { - return (channel.getSelectionKey().interestOps() & SelectionKey.OP_WRITE) != 0; + public static boolean isWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { + return (selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0; } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java index 551cab48e05..4b47ce063f9 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java @@ -20,43 +20,50 @@ package org.elasticsearch.nio; import java.io.IOException; +import java.nio.channels.ServerSocketChannel; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; -public class ServerChannelContext implements ChannelContext { +public class ServerChannelContext extends ChannelContext { private final NioServerSocketChannel channel; + private final AcceptingSelector selector; private final Consumer acceptor; - private final BiConsumer exceptionHandler; private final AtomicBoolean isClosing = new AtomicBoolean(false); + private final ChannelFactory channelFactory; - public ServerChannelContext(NioServerSocketChannel channel, Consumer acceptor, - BiConsumer exceptionHandler) { + public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector, + Consumer acceptor, Consumer exceptionHandler) { + super(channel.getRawChannel(), exceptionHandler); this.channel = channel; + this.channelFactory = channelFactory; + this.selector = selector; this.acceptor = acceptor; - this.exceptionHandler = exceptionHandler; } - public void acceptChannel(NioSocketChannel acceptedChannel) { - acceptor.accept(acceptedChannel); - } - - @Override - public void closeFromSelector() throws IOException { - channel.closeFromSelector(); + public void acceptChannels(Supplier selectorSupplier) throws IOException { + NioSocketChannel acceptedChannel; + while ((acceptedChannel = channelFactory.acceptNioChannel(this, selectorSupplier)) != null) { + acceptor.accept(acceptedChannel); + } } @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - channel.getSelector().queueChannelClose(channel); + getSelector().queueChannelClose(channel); } } @Override - public void handleException(Exception e) { - exceptionHandler.accept(channel, e); + public AcceptingSelector getSelector() { + return selector; } + + @Override + public NioServerSocketChannel getChannel() { + return channel; + } + } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 62f82e8995d..3bf47a98e02 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -21,7 +21,10 @@ package org.elasticsearch.nio; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.function.Consumer; /** * This context should implement the specific logic for a channel. When a channel receives a notification @@ -32,24 +35,78 @@ import java.util.function.BiConsumer; * The only methods of the context that should ever be called from a non-selector thread are * {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}. */ -public abstract class SocketChannelContext implements ChannelContext { +public abstract class SocketChannelContext extends ChannelContext { protected final NioSocketChannel channel; - private final BiConsumer exceptionHandler; + private final SocketSelector selector; + private final CompletableFuture connectContext = new CompletableFuture<>(); private boolean ioException; private boolean peerClosed; + private Exception connectException; - protected SocketChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler) { + protected SocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { + super(channel.getRawChannel(), exceptionHandler); + this.selector = selector; this.channel = channel; - this.exceptionHandler = exceptionHandler; } @Override - public void handleException(Exception e) { - exceptionHandler.accept(channel, e); + public SocketSelector getSelector() { + return selector; } - public void channelRegistered() throws IOException {} + @Override + public NioSocketChannel getChannel() { + return channel; + } + + public void addConnectListener(BiConsumer listener) { + connectContext.whenComplete(listener); + } + + public boolean isConnectComplete() { + return connectContext.isDone() && connectContext.isCompletedExceptionally() == false; + } + + /** + * This method will attempt to complete the connection process for this channel. It should be called for + * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then + * the connection is complete and the channel is ready for reads and writes. If it returns false, the + * channel is not yet connected and this method should be called again when a OP_CONNECT event is + * received. + * + * @return true if the connection process is complete + * @throws IOException if an I/O error occurs + */ + public boolean connect() throws IOException { + if (isConnectComplete()) { + return true; + } else if (connectContext.isCompletedExceptionally()) { + Exception exception = connectException; + if (exception == null) { + throw new AssertionError("Should have received connection exception"); + } else if (exception instanceof IOException) { + throw (IOException) exception; + } else { + throw (RuntimeException) exception; + } + } + + boolean isConnected = rawChannel.isConnected(); + if (isConnected == false) { + try { + isConnected = rawChannel.finishConnect(); + } catch (IOException | RuntimeException e) { + connectException = e; + connectContext.completeExceptionally(e); + throw e; + } + } + if (isConnected) { + connectContext.complete(null); + } + return isConnected; + } public abstract int read() throws IOException; @@ -78,7 +135,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int readFromChannel(ByteBuffer buffer) throws IOException { try { - int bytesRead = channel.read(buffer); + int bytesRead = rawChannel.read(buffer); if (bytesRead < 0) { peerClosed = true; bytesRead = 0; @@ -92,7 +149,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int readFromChannel(ByteBuffer[] buffers) throws IOException { try { - int bytesRead = channel.read(buffers); + int bytesRead = (int) rawChannel.read(buffers); if (bytesRead < 0) { peerClosed = true; bytesRead = 0; @@ -106,7 +163,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int flushToChannel(ByteBuffer buffer) throws IOException { try { - return channel.write(buffer); + return rawChannel.write(buffer); } catch (IOException e) { ioException = true; throw e; @@ -115,7 +172,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int flushToChannel(ByteBuffer[] buffers) throws IOException { try { - return channel.write(buffers); + return (int) rawChannel.write(buffers); } catch (IOException e) { ioException = true; throw e; diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java index b1192f11eb1..b1f73864761 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import java.io.IOException; +import java.nio.channels.SelectionKey; import java.util.function.BiConsumer; /** @@ -41,91 +42,93 @@ public class SocketEventHandler extends EventHandler { * This method is called when a NioSocketChannel is successfully registered. It should only be called * once per channel. * - * @param channel that was registered + * @param context that was registered */ - protected void handleRegistration(NioSocketChannel channel) throws IOException { - SocketChannelContext context = channel.getContext(); - context.channelRegistered(); + protected void handleRegistration(SocketChannelContext context) throws IOException { + context.register(); + SelectionKey selectionKey = context.getSelectionKey(); + selectionKey.attach(context); if (context.hasQueuedWriteOps()) { - SelectionKeyUtils.setConnectReadAndWriteInterested(channel); + SelectionKeyUtils.setConnectReadAndWriteInterested(selectionKey); } else { - SelectionKeyUtils.setConnectAndReadInterested(channel); + SelectionKeyUtils.setConnectAndReadInterested(selectionKey); } } /** * This method is called when an attempt to register a channel throws an exception. * - * @param channel that was registered + * @param context that was registered * @param exception that occurred */ - protected void registrationException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("failed to register socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void registrationException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("failed to register socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** - * This method is called when a NioSocketChannel is successfully connected. It should only be called - * once per channel. + * This method is called when a NioSocketChannel has just been accepted or if it has receive an + * OP_CONNECT event. * - * @param channel that was registered + * @param context that was registered */ - protected void handleConnect(NioSocketChannel channel) { - SelectionKeyUtils.removeConnectInterested(channel); + protected void handleConnect(SocketChannelContext context) throws IOException { + if (context.connect()) { + SelectionKeyUtils.removeConnectInterested(context.getSelectionKey()); + } } /** * This method is called when an attempt to connect a channel throws an exception. * - * @param channel that was connecting + * @param context that was connecting * @param exception that occurred */ - protected void connectException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("failed to connect to socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void connectException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("failed to connect to socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** * This method is called when a channel signals it is ready for be read. All of the read logic should * occur in this call. * - * @param channel that can be read + * @param context that can be read */ - protected void handleRead(NioSocketChannel channel) throws IOException { - channel.getContext().read(); + protected void handleRead(SocketChannelContext context) throws IOException { + context.read(); } /** * This method is called when an attempt to read from a channel throws an exception. * - * @param channel that was being read + * @param context that was being read * @param exception that occurred */ - protected void readException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while reading from socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void readException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while reading from socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** * This method is called when a channel signals it is ready to receive writes. All of the write logic * should occur in this call. * - * @param channel that can be written to + * @param context that can be written to */ - protected void handleWrite(NioSocketChannel channel) throws IOException { - SocketChannelContext channelContext = channel.getContext(); - channelContext.flushChannel(); + protected void handleWrite(SocketChannelContext context) throws IOException { + context.flushChannel(); } /** * This method is called when an attempt to write to a channel throws an exception. * - * @param channel that was being written to + * @param context that was being written to * @param exception that occurred */ - protected void writeException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while writing to socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void writeException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while writing to socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** @@ -139,18 +142,19 @@ public class SocketEventHandler extends EventHandler { } /** - * @param channel that was handled + * @param context that was handled */ - protected void postHandling(NioSocketChannel channel) { - if (channel.getContext().selectorShouldClose()) { - handleClose(channel); + protected void postHandling(SocketChannelContext context) { + if (context.selectorShouldClose()) { + handleClose(context); } else { - boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(channel); - boolean pendingWrites = channel.getContext().hasQueuedWriteOps(); + SelectionKey selectionKey = context.getSelectionKey(); + boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey); + boolean pendingWrites = context.hasQueuedWriteOps(); if (currentlyWriteInterested == false && pendingWrites) { - SelectionKeyUtils.setWriteInterested(channel); + SelectionKeyUtils.setWriteInterested(selectionKey); } else if (currentlyWriteInterested && pendingWrites == false) { - SelectionKeyUtils.removeWriteInterested(channel); + SelectionKeyUtils.removeWriteInterested(selectionKey); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java index acfec6ca04e..b1a3a08f02d 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java @@ -33,7 +33,7 @@ import java.util.function.BiConsumer; */ public class SocketSelector extends ESSelector { - private final ConcurrentLinkedQueue newChannels = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue newChannels = new ConcurrentLinkedQueue<>(); private final ConcurrentLinkedQueue queuedWrites = new ConcurrentLinkedQueue<>(); private final SocketEventHandler eventHandler; @@ -49,23 +49,23 @@ public class SocketSelector extends ESSelector { @Override void processKey(SelectionKey selectionKey) { - NioSocketChannel nioSocketChannel = (NioSocketChannel) selectionKey.attachment(); + SocketChannelContext channelContext = (SocketChannelContext) selectionKey.attachment(); int ops = selectionKey.readyOps(); if ((ops & SelectionKey.OP_CONNECT) != 0) { - attemptConnect(nioSocketChannel, true); + attemptConnect(channelContext, true); } - if (nioSocketChannel.isConnectComplete()) { + if (channelContext.isConnectComplete()) { if ((ops & SelectionKey.OP_WRITE) != 0) { - handleWrite(nioSocketChannel); + handleWrite(channelContext); } if ((ops & SelectionKey.OP_READ) != 0) { - handleRead(nioSocketChannel); + handleRead(channelContext); } } - eventHandler.postHandling(nioSocketChannel); + eventHandler.postHandling(channelContext); } @Override @@ -89,8 +89,9 @@ public class SocketSelector extends ESSelector { * @param nioSocketChannel the channel to register */ public void scheduleForRegistration(NioSocketChannel nioSocketChannel) { - newChannels.offer(nioSocketChannel); - ensureSelectorOpenForEnqueuing(newChannels, nioSocketChannel); + SocketChannelContext channelContext = nioSocketChannel.getContext(); + newChannels.offer(channelContext); + ensureSelectorOpenForEnqueuing(newChannels, channelContext); wakeup(); } @@ -121,10 +122,9 @@ public class SocketSelector extends ESSelector { */ public void queueWriteInChannelBuffer(WriteOperation writeOperation) { assertOnSelectorThread(); - NioSocketChannel channel = writeOperation.getChannel(); - SocketChannelContext context = channel.getContext(); + SocketChannelContext context = writeOperation.getChannel(); try { - SelectionKeyUtils.setWriteInterested(channel); + SelectionKeyUtils.setWriteInterested(context.getSelectionKey()); context.queueWriteOperation(writeOperation); } catch (Exception e) { executeFailedListener(writeOperation.getListener(), e); @@ -163,19 +163,19 @@ public class SocketSelector extends ESSelector { } } - private void handleWrite(NioSocketChannel nioSocketChannel) { + private void handleWrite(SocketChannelContext context) { try { - eventHandler.handleWrite(nioSocketChannel); + eventHandler.handleWrite(context); } catch (Exception e) { - eventHandler.writeException(nioSocketChannel, e); + eventHandler.writeException(context, e); } } - private void handleRead(NioSocketChannel nioSocketChannel) { + private void handleRead(SocketChannelContext context) { try { - eventHandler.handleRead(nioSocketChannel); + eventHandler.handleRead(context); } catch (Exception e) { - eventHandler.readException(nioSocketChannel, e); + eventHandler.readException(context, e); } } @@ -191,38 +191,34 @@ public class SocketSelector extends ESSelector { } private void setUpNewChannels() { - NioSocketChannel newChannel; - while ((newChannel = this.newChannels.poll()) != null) { - setupChannel(newChannel); + SocketChannelContext channelContext; + while ((channelContext = this.newChannels.poll()) != null) { + setupChannel(channelContext); } } - private void setupChannel(NioSocketChannel newChannel) { - assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; + private void setupChannel(SocketChannelContext context) { + assert context.getSelector() == this : "The channel must be registered with the selector with which it was created"; try { - if (newChannel.isOpen()) { - newChannel.register(); - SelectionKey key = newChannel.getSelectionKey(); - key.attach(newChannel); - eventHandler.handleRegistration(newChannel); - attemptConnect(newChannel, false); + if (context.isOpen()) { + eventHandler.handleRegistration(context); + attemptConnect(context, false); } else { - eventHandler.registrationException(newChannel, new ClosedChannelException()); + eventHandler.registrationException(context, new ClosedChannelException()); } } catch (Exception e) { - eventHandler.registrationException(newChannel, e); + eventHandler.registrationException(context, e); } } - private void attemptConnect(NioSocketChannel newChannel, boolean connectEvent) { + private void attemptConnect(SocketChannelContext context, boolean connectEvent) { try { - if (newChannel.finishConnect()) { - eventHandler.handleConnect(newChannel); - } else if (connectEvent) { - eventHandler.connectException(newChannel, new IOException("Received OP_CONNECT but connect failed")); + eventHandler.handleConnect(context); + if (connectEvent && context.isConnectComplete() == false) { + eventHandler.connectException(context, new IOException("Received OP_CONNECT but connect failed")); } } catch (Exception e) { - eventHandler.connectException(newChannel, e); + eventHandler.connectException(context, e); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java index d2dfe4f37a0..665b9f7759e 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java @@ -30,5 +30,5 @@ public interface WriteOperation { BiConsumer getListener(); - NioSocketChannel getChannel(); + SocketChannelContext getChannel(); } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java index 048aa3af8ff..7536ad9d1e1 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java @@ -44,6 +44,7 @@ public class AcceptingSelectorTests extends ESTestCase { private AcceptorEventHandler eventHandler; private TestSelectionKey selectionKey; private Selector rawSelector; + private ServerChannelContext context; @Before public void setUp() throws Exception { @@ -56,39 +57,41 @@ public class AcceptingSelectorTests extends ESTestCase { selector = new AcceptingSelector(eventHandler, rawSelector); this.selector.setThread(); + context = mock(ServerChannelContext.class); selectionKey = new TestSelectionKey(0); - selectionKey.attach(serverChannel); - when(serverChannel.getSelectionKey()).thenReturn(selectionKey); - when(serverChannel.getSelector()).thenReturn(selector); - when(serverChannel.isOpen()).thenReturn(true); + selectionKey.attach(context); + when(context.getSelectionKey()).thenReturn(selectionKey); + when(context.getSelector()).thenReturn(selector); + when(context.isOpen()).thenReturn(true); + when(serverChannel.getContext()).thenReturn(context); } - public void testRegisteredChannel() throws IOException, PrivilegedActionException { + public void testRegisteredChannel() throws IOException { selector.scheduleForRegistration(serverChannel); selector.preSelect(); - verify(eventHandler).serverChannelRegistered(serverChannel); + verify(eventHandler).handleRegistration(context); } - public void testClosedChannelWillNotBeRegistered() throws Exception { - when(serverChannel.isOpen()).thenReturn(false); + public void testClosedChannelWillNotBeRegistered() { + when(context.isOpen()).thenReturn(false); selector.scheduleForRegistration(serverChannel); selector.preSelect(); - verify(eventHandler).registrationException(same(serverChannel), any(ClosedChannelException.class)); + verify(eventHandler).registrationException(same(context), any(ClosedChannelException.class)); } public void testRegisterChannelFailsDueToException() throws Exception { selector.scheduleForRegistration(serverChannel); ClosedChannelException closedChannelException = new ClosedChannelException(); - doThrow(closedChannelException).when(serverChannel).register(); + doThrow(closedChannelException).when(eventHandler).handleRegistration(context); selector.preSelect(); - verify(eventHandler).registrationException(serverChannel, closedChannelException); + verify(eventHandler).registrationException(context, closedChannelException); } public void testAcceptEvent() throws IOException { @@ -96,18 +99,18 @@ public class AcceptingSelectorTests extends ESTestCase { selector.processKey(selectionKey); - verify(eventHandler).acceptChannel(serverChannel); + verify(eventHandler).acceptChannel(context); } public void testAcceptException() throws IOException { selectionKey.setReadyOps(SelectionKey.OP_ACCEPT); IOException ioException = new IOException(); - doThrow(ioException).when(eventHandler).acceptChannel(serverChannel); + doThrow(ioException).when(eventHandler).acceptChannel(context); selector.processKey(selectionKey); - verify(eventHandler).acceptException(serverChannel, ioException); + verify(eventHandler).acceptException(context, ioException); } public void testCleanup() throws IOException { @@ -116,11 +119,11 @@ public class AcceptingSelectorTests extends ESTestCase { selector.preSelect(); TestSelectionKey key = new TestSelectionKey(0); - key.attach(serverChannel); + key.attach(context); when(rawSelector.keys()).thenReturn(new HashSet<>(Collections.singletonList(key))); selector.cleanupAndCloseChannels(); - verify(eventHandler).handleClose(serverChannel); + verify(eventHandler).handleClose(context); } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java index 23ab3bb3e1d..50469b30acd 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java @@ -27,62 +27,89 @@ import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; +import java.util.function.Consumer; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AcceptorEventHandlerTests extends ESTestCase { private AcceptorEventHandler handler; - private SocketSelector socketSelector; private ChannelFactory channelFactory; private NioServerSocketChannel channel; - private ServerChannelContext context; + private DoNotRegisterContext context; + private RoundRobinSupplier selectorSupplier; @Before @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { channelFactory = mock(ChannelFactory.class); - socketSelector = mock(SocketSelector.class); - context = mock(ServerChannelContext.class); ArrayList selectors = new ArrayList<>(); - selectors.add(socketSelector); - handler = new AcceptorEventHandler(logger, new RoundRobinSupplier<>(selectors.toArray(new SocketSelector[selectors.size()]))); + selectors.add(mock(SocketSelector.class)); + selectorSupplier = new RoundRobinSupplier<>(selectors.toArray(new SocketSelector[selectors.size()])); + handler = new AcceptorEventHandler(logger, selectorSupplier); - AcceptingSelector selector = mock(AcceptingSelector.class); - channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector); + channel = new NioServerSocketChannel(mock(ServerSocketChannel.class)); + context = new DoNotRegisterContext(channel, mock(AcceptingSelector.class), mock(Consumer.class)); channel.setContext(context); - channel.register(); } - public void testHandleRegisterSetsOP_ACCEPTInterest() { - assertEquals(0, channel.getSelectionKey().interestOps()); + public void testHandleRegisterSetsOP_ACCEPTInterest() throws IOException { + assertNull(context.getSelectionKey()); - handler.serverChannelRegistered(channel); + handler.handleRegistration(context); - assertEquals(SelectionKey.OP_ACCEPT, channel.getSelectionKey().interestOps()); + assertEquals(SelectionKey.OP_ACCEPT, channel.getContext().getSelectionKey().interestOps()); + } + + public void testRegisterAddsAttachment() throws IOException { + assertNull(context.getSelectionKey()); + + handler.handleRegistration(context); + + assertEquals(context, context.getSelectionKey().attachment()); } public void testHandleAcceptCallsChannelFactory() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); + NioSocketChannel nullChannel = null; + when(channelFactory.acceptNioChannel(same(context), same(selectorSupplier))).thenReturn(childChannel, nullChannel); - handler.acceptChannel(channel); - - verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector)); + handler.acceptChannel(context); + verify(channelFactory, times(2)).acceptNioChannel(same(context), same(selectorSupplier)); } @SuppressWarnings("unchecked") public void testHandleAcceptCallsServerAcceptCallback() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); - childChannel.setContext(mock(SocketChannelContext.class)); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); + SocketChannelContext childContext = mock(SocketChannelContext.class); + childChannel.setContext(childContext); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + channel = new NioServerSocketChannel(mock(ServerSocketChannel.class)); + channel.setContext(serverChannelContext); + when(serverChannelContext.getChannel()).thenReturn(channel); + when(channelFactory.acceptNioChannel(same(context), same(selectorSupplier))).thenReturn(childChannel); - handler.acceptChannel(channel); + handler.acceptChannel(serverChannelContext); - verify(context).acceptChannel(childChannel); + verify(serverChannelContext).acceptChannels(selectorSupplier); + } + + private class DoNotRegisterContext extends ServerChannelContext { + + + @SuppressWarnings("unchecked") + DoNotRegisterContext(NioServerSocketChannel channel, AcceptingSelector selector, Consumer acceptor) { + super(channel, channelFactory, selector, acceptor, mock(Consumer.class)); + } + + @Override + public void register() { + setSelectionKey(new TestSelectionKey(0)); + } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index 68ae1f2e503..d9de0ab1361 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -27,10 +27,13 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; @@ -42,9 +45,11 @@ public class BytesChannelContextTests extends ESTestCase { private SocketChannelContext.ReadConsumer readConsumer; private NioSocketChannel channel; + private SocketChannel rawChannel; private BytesChannelContext context; private InboundChannelBuffer channelBuffer; private SocketSelector selector; + private Consumer exceptionHandler; private BiConsumer listener; private int messageLength; @@ -57,17 +62,19 @@ public class BytesChannelContextTests extends ESTestCase { selector = mock(SocketSelector.class); listener = mock(BiConsumer.class); channel = mock(NioSocketChannel.class); + rawChannel = mock(SocketChannel.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new BytesChannelContext(channel, null, readConsumer, channelBuffer); + exceptionHandler = mock(Consumer.class); + when(channel.getRawChannel()).thenReturn(rawChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - when(channel.getSelector()).thenReturn(selector); when(selector.isOnCurrentThread()).thenReturn(true); } public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -85,7 +92,7 @@ public class BytesChannelContextTests extends ESTestCase { public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -103,7 +110,7 @@ public class BytesChannelContextTests extends ESTestCase { public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -128,14 +135,14 @@ public class BytesChannelContextTests extends ESTestCase { public void testReadThrowsIOException() throws IOException { IOException ioException = new IOException(); - when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(ioException); IOException ex = expectThrows(IOException.class, () -> context.read()); assertSame(ioException, ex); } public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer[].class))).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.read()); @@ -143,22 +150,28 @@ public class BytesChannelContextTests extends ESTestCase { } public void testReadLessThanZeroMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer[].class))).thenReturn(-1); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L); assertEquals(0, context.read()); assertTrue(context.selectorShouldClose()); } + @SuppressWarnings("unchecked") public void testCloseClosesChannelBuffer() throws IOException { - when(channel.isOpen()).thenReturn(true); - Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - buffer.ensureCapacity(1); - BytesChannelContext context = new BytesChannelContext(channel, null, readConsumer, buffer); - context.closeFromSelector(); - verify(closer).run(); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); + + when(channel.isOpen()).thenReturn(true); + Runnable closer = mock(Runnable.class); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + buffer.ensureCapacity(1); + BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, buffer); + context.closeFromSelector(); + verify(closer).run(); + } } public void testWriteFailsIfClosing() { @@ -182,7 +195,7 @@ public class BytesChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -196,7 +209,7 @@ public class BytesChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -204,25 +217,31 @@ public class BytesChannelContextTests extends ESTestCase { assertFalse(context.hasQueuedWriteOps()); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); assertTrue(context.hasQueuedWriteOps()); } + @SuppressWarnings("unchecked") public void testWriteOpsClearedOnClose() throws Exception { - assertFalse(context.hasQueuedWriteOps()); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + assertFalse(context.hasQueuedWriteOps()); - assertTrue(context.hasQueuedWriteOps()); + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); + assertTrue(context.hasQueuedWriteOps()); - verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); - assertFalse(context.hasQueuedWriteOps()); + verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + + assertFalse(context.hasQueuedWriteOps()); + } } public void testQueuedWriteIsFlushedInFlushCall() throws Exception { @@ -239,7 +258,7 @@ public class BytesChannelContextTests extends ESTestCase { when(writeOperation.getListener()).thenReturn(listener); context.flushChannel(); - verify(channel).write(buffers); + verify(rawChannel).write(buffers, 0, buffers.length); verify(selector).executeListener(listener, null); assertFalse(context.hasQueuedWriteOps()); } @@ -253,6 +272,7 @@ public class BytesChannelContextTests extends ESTestCase { assertTrue(context.hasQueuedWriteOps()); when(writeOperation.isFullyFlushed()).thenReturn(false); + when(writeOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); context.flushChannel(); verify(listener, times(0)).accept(null, null); @@ -266,6 +286,8 @@ public class BytesChannelContextTests extends ESTestCase { BiConsumer listener2 = mock(BiConsumer.class); BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class); BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class); + when(writeOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); + when(writeOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); when(writeOperation1.getListener()).thenReturn(listener); when(writeOperation2.getListener()).thenReturn(listener2); context.queueWriteOperation(writeOperation1); @@ -300,7 +322,7 @@ public class BytesChannelContextTests extends ESTestCase { IOException exception = new IOException(); when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(channel.write(buffers)).thenThrow(exception); + when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); when(writeOperation.getListener()).thenReturn(listener); expectThrows(IOException.class, () -> context.flushChannel()); @@ -315,7 +337,7 @@ public class BytesChannelContextTests extends ESTestCase { IOException exception = new IOException(); when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(channel.write(buffers)).thenThrow(exception); + when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.flushChannel()); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java similarity index 91% rename from libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java rename to libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java index 59fb9cde438..05afc80a490 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java @@ -25,29 +25,26 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.List; import java.util.function.BiConsumer; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -public class WriteOperationTests extends ESTestCase { +public class BytesWriteOperationTests extends ESTestCase { - private NioSocketChannel channel; + private SocketChannelContext channelContext; private BiConsumer listener; @Before @SuppressWarnings("unchecked") public void setFields() { - channel = mock(NioSocketChannel.class); + channelContext = mock(SocketChannelContext.class); listener = mock(BiConsumer.class); } public void testFullyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); writeOp.incrementIndex(10); @@ -56,7 +53,7 @@ public class WriteOperationTests extends ESTestCase { public void testPartiallyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); writeOp.incrementIndex(5); @@ -65,7 +62,7 @@ public class WriteOperationTests extends ESTestCase { public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); ArgumentCaptor buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java new file mode 100644 index 00000000000..f262dd06330 --- /dev/null +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java @@ -0,0 +1,214 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class ChannelContextTests extends ESTestCase { + + private TestChannelContext context; + private Consumer exceptionHandler; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + super.setUp(); + exceptionHandler = mock(Consumer.class); + } + + public void testCloseSuccess() throws IOException { + FakeRawChannel rawChannel = new FakeRawChannel(null); + context = new TestChannelContext(rawChannel, exceptionHandler); + + AtomicBoolean listenerCalled = new AtomicBoolean(false); + context.addCloseListener((v, t) -> { + if (t == null) { + listenerCalled.compareAndSet(false, true); + } else { + throw new AssertionError("Close should not fail"); + } + }); + + assertFalse(rawChannel.hasCloseBeenCalled()); + assertTrue(context.isOpen()); + assertFalse(listenerCalled.get()); + context.closeFromSelector(); + assertTrue(rawChannel.hasCloseBeenCalled()); + assertFalse(context.isOpen()); + assertTrue(listenerCalled.get()); + } + + public void testCloseException() throws IOException { + IOException ioException = new IOException("boom"); + FakeRawChannel rawChannel = new FakeRawChannel(ioException); + context = new TestChannelContext(rawChannel, exceptionHandler); + + AtomicReference exception = new AtomicReference<>(); + context.addCloseListener((v, t) -> { + if (t == null) { + throw new AssertionError("Close should not fail"); + } else { + exception.set((Exception) t); + } + }); + + assertFalse(rawChannel.hasCloseBeenCalled()); + assertTrue(context.isOpen()); + assertNull(exception.get()); + expectThrows(IOException.class, context::closeFromSelector); + assertTrue(rawChannel.hasCloseBeenCalled()); + assertFalse(context.isOpen()); + assertSame(ioException, exception.get()); + } + + public void testExceptionsAreDelegatedToHandler() { + context = new TestChannelContext(new FakeRawChannel(null), exceptionHandler); + IOException exception = new IOException(); + context.handleException(exception); + verify(exceptionHandler).accept(exception); + } + + private static class TestChannelContext extends ChannelContext { + + private TestChannelContext(FakeRawChannel channel, Consumer exceptionHandler) { + super(channel, exceptionHandler); + } + + @Override + public void closeChannel() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public ESSelector getSelector() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public NioChannel getChannel() { + throw new UnsupportedOperationException("not implemented"); + } + } + + private class FakeRawChannel extends SelectableChannel implements NetworkChannel { + + private final IOException exceptionOnClose; + private AtomicBoolean hasCloseBeenCalled = new AtomicBoolean(false); + + private FakeRawChannel(IOException exceptionOnClose) { + this.exceptionOnClose = exceptionOnClose; + } + + @Override + protected void implCloseChannel() throws IOException { + hasCloseBeenCalled.compareAndSet(false, true); + if (exceptionOnClose != null) { + throw exceptionOnClose; + } + } + + private boolean hasCloseBeenCalled() { + return hasCloseBeenCalled.get(); + } + + @Override + public NetworkChannel bind(SocketAddress local) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SocketAddress getLocalAddress() throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public NetworkChannel setOption(SocketOption name, T value) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public T getOption(SocketOption name) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public Set> supportedOptions() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectorProvider provider() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public int validOps() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public boolean isRegistered() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectionKey keyFor(Selector sel) { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectionKey register(Selector sel, int ops, Object att) throws ClosedChannelException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectableChannel configureBlocking(boolean block) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public boolean isBlocking() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public Object blockingLock() { + throw new UnsupportedOperationException("not implemented"); + } + } +} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java index 1c8a8a130cc..99880a2fd80 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; +import java.util.function.Supplier; import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; @@ -43,6 +44,8 @@ public class ChannelFactoryTests extends ESTestCase { private SocketChannel rawChannel; private ServerSocketChannel rawServerChannel; private SocketSelector socketSelector; + private Supplier socketSelectorSupplier; + private Supplier acceptingSelectorSupplier; private AcceptingSelector acceptingSelector; @Before @@ -52,34 +55,36 @@ public class ChannelFactoryTests extends ESTestCase { channelFactory = new TestChannelFactory(rawChannelFactory); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); + socketSelectorSupplier = mock(Supplier.class); + acceptingSelectorSupplier = mock(Supplier.class); rawChannel = SocketChannel.open(); rawServerChannel = ServerSocketChannel.open(); + when(socketSelectorSupplier.get()).thenReturn(socketSelector); + when(acceptingSelectorSupplier.get()).thenReturn(acceptingSelector); } @After public void ensureClosed() throws IOException { - IOUtils.closeWhileHandlingException(rawChannel); - IOUtils.closeWhileHandlingException(rawServerChannel); + IOUtils.closeWhileHandlingException(rawChannel, rawServerChannel); } public void testAcceptChannel() throws IOException { - NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); - when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); - NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector); + NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier); verify(socketSelector).scheduleForRegistration(channel); - assertEquals(socketSelector, channel.getSelector()); assertEquals(rawChannel, channel.getRawChannel()); } public void testAcceptedChannelRejected() throws IOException { - NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); - when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannel, socketSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier)); assertFalse(rawChannel.isOpen()); } @@ -88,11 +93,10 @@ public class ChannelFactoryTests extends ESTestCase { InetSocketAddress address = mock(InetSocketAddress.class); when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); - NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelector); + NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier); verify(socketSelector).scheduleForRegistration(channel); - assertEquals(socketSelector, channel.getSelector()); assertEquals(rawChannel, channel.getRawChannel()); } @@ -101,7 +105,7 @@ public class ChannelFactoryTests extends ESTestCase { when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier)); assertFalse(rawChannel.isOpen()); } @@ -110,11 +114,10 @@ public class ChannelFactoryTests extends ESTestCase { InetSocketAddress address = mock(InetSocketAddress.class); when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); - NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelector); + NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier); verify(acceptingSelector).scheduleForRegistration(channel); - assertEquals(acceptingSelector, channel.getSelector()); assertEquals(rawServerChannel, channel.getRawChannel()); } @@ -123,7 +126,7 @@ public class ChannelFactoryTests extends ESTestCase { when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier)); assertFalse(rawServerChannel.isOpen()); } @@ -137,14 +140,14 @@ public class ChannelFactoryTests extends ESTestCase { @SuppressWarnings("unchecked") @Override public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { - NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); + NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); nioSocketChannel.setContext(mock(SocketChannelContext.class)); return nioSocketChannel; } @Override public NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - return new NioServerSocketChannel(channel, this, selector); + return new NioServerSocketChannel(channel); } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java deleted file mode 100644 index dd73d43292a..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.SocketChannel; - -public class DoNotRegisterChannel extends NioSocketChannel { - - public DoNotRegisterChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); - } - - @Override - public void register() throws ClosedChannelException { - setSelectionKey(new TestSelectionKey(0)); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java deleted file mode 100644 index 1d5e605c444..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.ServerSocketChannel; - -public class DoNotRegisterServerChannel extends NioServerSocketChannel { - - public DoNotRegisterServerChannel(ServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector) - throws IOException { - super(channel, channelFactory, selector); - } - - @Override - public void register() throws ClosedChannelException { - setSelectionKey(new TestSelectionKey(0)); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java index 1dab05487a1..cb8f0757fb9 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java @@ -27,6 +27,7 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedSelectorException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.mock; @@ -47,15 +48,18 @@ public class ESSelectorTests extends ESTestCase { selector = new TestSelector(handler, rawSelector); } + @SuppressWarnings({"unchecked", "rawtypes"}) public void testQueueChannelForClosed() throws IOException { NioChannel channel = mock(NioChannel.class); - when(channel.getSelector()).thenReturn(selector); + ChannelContext context = mock(ChannelContext.class); + when(channel.getContext()).thenReturn(context); + when(context.getSelector()).thenReturn(selector); selector.queueChannelClose(channel); selector.singleLoop(); - verify(handler).handleClose(channel); + verify(handler).handleClose(context); } public void testSelectorClosedExceptionIsNotCaughtWhileRunning() throws IOException { diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java deleted file mode 100644 index 12a77a425eb..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.FutureUtils; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.nio.channels.ServerSocketChannel; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import static org.mockito.Mockito.mock; - -public class NioServerSocketChannelTests extends ESTestCase { - - private AcceptingSelector selector; - private AtomicBoolean closedRawChannel; - private Thread thread; - - @Before - @SuppressWarnings("unchecked") - public void setSelector() throws IOException { - selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(Supplier.class))); - thread = new Thread(selector::runLoop); - closedRawChannel = new AtomicBoolean(false); - thread.start(); - FutureUtils.get(selector.isRunningFuture()); - } - - @After - public void stopSelector() throws IOException, InterruptedException { - selector.close(); - thread.join(); - } - - @SuppressWarnings("unchecked") - public void testClose() throws Exception { - AtomicBoolean isClosed = new AtomicBoolean(false); - CountDownLatch latch = new CountDownLatch(1); - - try (ServerSocketChannel rawChannel = ServerSocketChannel.open()) { - NioServerSocketChannel channel = new NioServerSocketChannel(rawChannel, mock(ChannelFactory.class), selector); - channel.setContext(new ServerChannelContext(channel, mock(Consumer.class), mock(BiConsumer.class))); - channel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { - @Override - public void onResponse(Void o) { - isClosed.set(true); - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - isClosed.set(true); - latch.countDown(); - } - })); - - assertTrue(channel.isOpen()); - assertTrue(rawChannel.isOpen()); - assertFalse(isClosed.get()); - - PlainActionFuture closeFuture = PlainActionFuture.newFuture(); - channel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - selector.queueChannelClose(channel); - closeFuture.actionGet(); - - - assertFalse(rawChannel.isOpen()); - assertFalse(channel.isOpen()); - latch.await(); - assertTrue(isClosed.get()); - } - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java deleted file mode 100644 index bbda9233bbb..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.FutureUtils; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.net.ConnectException; -import java.nio.channels.SocketChannel; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class NioSocketChannelTests extends ESTestCase { - - private SocketSelector selector; - private Thread thread; - - @Before - @SuppressWarnings("unchecked") - public void startSelector() throws IOException { - selector = new SocketSelector(new SocketEventHandler(logger)); - thread = new Thread(selector::runLoop); - thread.start(); - FutureUtils.get(selector.isRunningFuture()); - } - - @After - public void stopSelector() throws IOException, InterruptedException { - selector.close(); - thread.join(); - } - - @SuppressWarnings("unchecked") - public void testClose() throws Exception { - AtomicBoolean isClosed = new AtomicBoolean(false); - CountDownLatch latch = new CountDownLatch(1); - - try(SocketChannel rawChannel = SocketChannel.open()) { - NioSocketChannel socketChannel = new NioSocketChannel(rawChannel, selector); - socketChannel.setContext(new BytesChannelContext(socketChannel, mock(BiConsumer.class), - mock(SocketChannelContext.ReadConsumer.class), InboundChannelBuffer.allocatingInstance())); - socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { - @Override - public void onResponse(Void o) { - isClosed.set(true); - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - isClosed.set(true); - latch.countDown(); - } - })); - - assertTrue(socketChannel.isOpen()); - assertTrue(rawChannel.isOpen()); - assertFalse(isClosed.get()); - - PlainActionFuture closeFuture = PlainActionFuture.newFuture(); - socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - selector.queueChannelClose(socketChannel); - closeFuture.actionGet(); - - assertFalse(rawChannel.isOpen()); - assertFalse(socketChannel.isOpen()); - latch.await(); - assertTrue(isClosed.get()); - } - } - - @SuppressWarnings("unchecked") - public void testConnectSucceeds() throws Exception { - SocketChannel rawChannel = mock(SocketChannel.class); - when(rawChannel.finishConnect()).thenReturn(true); - NioSocketChannel socketChannel = new DoNotRegisterChannel(rawChannel, selector); - socketChannel.setContext(mock(SocketChannelContext.class)); - selector.scheduleForRegistration(socketChannel); - - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - socketChannel.addConnectListener(ActionListener.toBiConsumer(connectFuture)); - connectFuture.get(100, TimeUnit.SECONDS); - - assertTrue(socketChannel.isConnectComplete()); - assertTrue(socketChannel.isOpen()); - } - - @SuppressWarnings("unchecked") - public void testConnectFails() throws Exception { - SocketChannel rawChannel = mock(SocketChannel.class); - when(rawChannel.finishConnect()).thenThrow(new ConnectException()); - NioSocketChannel socketChannel = new DoNotRegisterChannel(rawChannel, selector); - socketChannel.setContext(mock(SocketChannelContext.class)); - selector.scheduleForRegistration(socketChannel); - - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - socketChannel.addConnectListener(ActionListener.toBiConsumer(connectFuture)); - ExecutionException e = expectThrows(ExecutionException.class, () -> connectFuture.get(100, TimeUnit.SECONDS)); - assertTrue(e.getCause() instanceof IOException); - - assertFalse(socketChannel.isConnectComplete()); - // Even if connection fails the channel is 'open' until close() is called - assertTrue(socketChannel.isOpen()); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java new file mode 100644 index 00000000000..17e6b7acba2 --- /dev/null +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -0,0 +1,173 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SocketChannelContextTests extends ESTestCase { + + private SocketChannel rawChannel; + private TestSocketChannelContext context; + private Consumer exceptionHandler; + private NioSocketChannel channel; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + super.setUp(); + + rawChannel = mock(SocketChannel.class); + channel = mock(NioSocketChannel.class); + when(channel.getRawChannel()).thenReturn(rawChannel); + exceptionHandler = mock(Consumer.class); + context = new TestSocketChannelContext(channel, mock(SocketSelector.class), exceptionHandler); + } + + public void testIOExceptionSetIfEncountered() throws IOException { + when(rawChannel.write(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); + when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException()); + assertFalse(context.hasIOException()); + expectThrows(IOException.class, () -> { + if (randomBoolean()) { + context.read(); + } else { + context.flushChannel(); + } + }); + assertTrue(context.hasIOException()); + } + + public void testSignalWhenPeerClosed() throws IOException { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L); + when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1); + assertFalse(context.isPeerClosed()); + context.read(); + assertTrue(context.isPeerClosed()); + } + + public void testConnectSucceeds() throws IOException { + AtomicBoolean listenerCalled = new AtomicBoolean(false); + when(rawChannel.finishConnect()).thenReturn(false, true); + + context.addConnectListener((v, t) -> { + if (t == null) { + listenerCalled.compareAndSet(false, true); + } else { + throw new AssertionError("Connection should not fail"); + } + }); + + assertFalse(context.connect()); + assertFalse(context.isConnectComplete()); + assertFalse(listenerCalled.get()); + assertTrue(context.connect()); + assertTrue(context.isConnectComplete()); + assertTrue(listenerCalled.get()); + } + + public void testConnectFails() throws IOException { + AtomicReference exception = new AtomicReference<>(); + IOException ioException = new IOException("boom"); + when(rawChannel.finishConnect()).thenReturn(false).thenThrow(ioException); + + context.addConnectListener((v, t) -> { + if (t == null) { + throw new AssertionError("Connection should not succeed"); + } else { + exception.set((Exception) t); + } + }); + + assertFalse(context.connect()); + assertFalse(context.isConnectComplete()); + assertNull(exception.get()); + expectThrows(IOException.class, context::connect); + assertFalse(context.isConnectComplete()); + assertSame(ioException, exception.get()); + } + + private static class TestSocketChannelContext extends SocketChannelContext { + + private TestSocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { + super(channel, selector, exceptionHandler); + } + + @Override + public int read() throws IOException { + if (randomBoolean()) { + ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; + return readFromChannel(byteBuffers); + } else { + return readFromChannel(ByteBuffer.allocate(10)); + } + } + + @Override + public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { + + } + + @Override + public void queueWriteOperation(WriteOperation writeOperation) { + + } + + @Override + public void flushChannel() throws IOException { + if (randomBoolean()) { + ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; + flushToChannel(byteBuffers); + } else { + flushToChannel(ByteBuffer.allocate(10)); + } + } + + @Override + public boolean hasQueuedWriteOps() { + return false; + } + + @Override + public boolean selectorShouldClose() { + return false; + } + + @Override + public void closeChannel() { + + } + } +} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java index d74214636db..4f476c1ff6b 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java @@ -23,12 +23,10 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.function.BiConsumer; -import java.util.function.Supplier; +import java.util.function.Consumer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -37,137 +35,171 @@ import static org.mockito.Mockito.when; public class SocketEventHandlerTests extends ESTestCase { - private BiConsumer exceptionHandler; + private Consumer exceptionHandler; private SocketEventHandler handler; private NioSocketChannel channel; private SocketChannel rawChannel; + private DoNotRegisterContext context; @Before @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { - exceptionHandler = mock(BiConsumer.class); - SocketSelector socketSelector = mock(SocketSelector.class); + exceptionHandler = mock(Consumer.class); + SocketSelector selector = mock(SocketSelector.class); handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); - channel = new DoNotRegisterChannel(rawChannel, socketSelector); + channel = new NioSocketChannel(rawChannel); when(rawChannel.finishConnect()).thenReturn(true); - InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); - channel.setContext(new BytesChannelContext(channel, exceptionHandler, mock(SocketChannelContext.ReadConsumer.class), buffer)); - channel.register(); - channel.finishConnect(); + context = new DoNotRegisterContext(channel, selector, exceptionHandler, new TestSelectionKey(0)); + channel.setContext(context); + handler.handleRegistration(context); - when(socketSelector.isOnCurrentThread()).thenReturn(true); + when(selector.isOnCurrentThread()).thenReturn(true); } public void testRegisterCallsContext() throws IOException { NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext channelContext = mock(SocketChannelContext.class); when(channel.getContext()).thenReturn(channelContext); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(channel); - verify(channelContext).channelRegistered(); + when(channelContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(channelContext); + verify(channelContext).register(); } public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException { - handler.handleRegistration(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps()); + SocketChannelContext context = mock(SocketChannelContext.class); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, context.getSelectionKey().interestOps()); + } + + public void testRegisterAddsAttachment() throws IOException { + SocketChannelContext context = mock(SocketChannelContext.class); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(context); + assertEquals(context, context.getSelectionKey().attachment()); } public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class)); - handler.handleRegistration(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + handler.handleRegistration(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps()); } public void testRegistrationExceptionCallsExceptionHandler() throws IOException { CancelledKeyException exception = new CancelledKeyException(); - handler.registrationException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.registrationException(context, exception); + verify(exceptionHandler).accept(exception); } - public void testConnectRemovesOP_CONNECTInterest() throws IOException { - SelectionKeyUtils.setConnectAndReadInterested(channel); - handler.handleConnect(channel); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + public void testConnectDoesNotRemoveOP_CONNECTInterestIfIncomplete() throws IOException { + SelectionKeyUtils.setConnectAndReadInterested(context.getSelectionKey()); + handler.handleConnect(context); + assertEquals(SelectionKey.OP_READ, context.getSelectionKey().interestOps()); + } + + public void testConnectRemovesOP_CONNECTInterestIfComplete() throws IOException { + SelectionKeyUtils.setConnectAndReadInterested(context.getSelectionKey()); + handler.handleConnect(context); + assertEquals(SelectionKey.OP_READ, context.getSelectionKey().interestOps()); } public void testConnectExceptionCallsExceptionHandler() throws IOException { IOException exception = new IOException(); - handler.connectException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.connectException(context, exception); + verify(exceptionHandler).accept(exception); } public void testHandleReadDelegatesToContext() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); + NioSocketChannel channel = new NioSocketChannel(rawChannel); SocketChannelContext context = mock(SocketChannelContext.class); channel.setContext(context); when(context.read()).thenReturn(1); - handler.handleRead(channel); + handler.handleRead(context); verify(context).read(); } public void testReadExceptionCallsExceptionHandler() { IOException exception = new IOException(); - handler.readException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.readException(context, exception); + verify(exceptionHandler).accept(exception); } public void testWriteExceptionCallsExceptionHandler() { IOException exception = new IOException(); - handler.writeException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.writeException(context, exception); + verify(exceptionHandler).accept(exception); } public void testPostHandlingCallWillCloseTheChannelIfReady() throws IOException { NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext context = mock(SocketChannelContext.class); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); when(channel.getContext()).thenReturn(context); when(context.selectorShouldClose()).thenReturn(true); - handler.postHandling(channel); + handler.postHandling(context); verify(context).closeFromSelector(); } public void testPostHandlingCallWillNotCloseTheChannelIfNotReady() throws IOException { - NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext context = mock(SocketChannelContext.class); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - - when(channel.getContext()).thenReturn(context); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE)); when(context.selectorShouldClose()).thenReturn(false); - handler.postHandling(channel); - verify(channel, times(0)).closeFromSelector(); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + handler.postHandling(context); + + verify(context, times(0)).closeFromSelector(); } public void testPostHandlingWillAddWriteIfNecessary() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); - channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ)); + TestSelectionKey selectionKey = new TestSelectionKey(SelectionKey.OP_READ); SocketChannelContext context = mock(SocketChannelContext.class); - channel.setContext(context); - + when(context.getSelectionKey()).thenReturn(selectionKey); when(context.hasQueuedWriteOps()).thenReturn(true); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); - handler.postHandling(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + handler.postHandling(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); } public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); - channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE)); + TestSelectionKey key = new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE); SocketChannelContext context = mock(SocketChannelContext.class); - channel.setContext(context); - + when(context.getSelectionKey()).thenReturn(key); when(context.hasQueuedWriteOps()).thenReturn(false); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); - handler.postHandling(channel); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, key.interestOps()); + handler.postHandling(context); + assertEquals(SelectionKey.OP_READ, key.interestOps()); + } + + private class DoNotRegisterContext extends BytesChannelContext { + + private final TestSelectionKey selectionKey; + + DoNotRegisterContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, + TestSelectionKey selectionKey) { + super(channel, selector, exceptionHandler, mock(ReadConsumer.class), InboundChannelBuffer.allocatingInstance()); + this.selectionKey = selectionKey; + } + + @Override + public void register() { + setSelectionKey(selectionKey); + } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java index 5992244b2f9..223f14455f9 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java @@ -58,22 +58,22 @@ public class SocketSelectorTests extends ESTestCase { @SuppressWarnings("unchecked") public void setUp() throws Exception { super.setUp(); + rawSelector = mock(Selector.class); eventHandler = mock(SocketEventHandler.class); channel = mock(NioSocketChannel.class); channelContext = mock(SocketChannelContext.class); listener = mock(BiConsumer.class); selectionKey = new TestSelectionKey(0); - selectionKey.attach(channel); - rawSelector = mock(Selector.class); + selectionKey.attach(channelContext); this.socketSelector = new SocketSelector(eventHandler, rawSelector); this.socketSelector.setThread(); - when(channel.isOpen()).thenReturn(true); - when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getContext()).thenReturn(channelContext); - when(channel.isConnectComplete()).thenReturn(true); - when(channel.getSelector()).thenReturn(socketSelector); + when(channelContext.isOpen()).thenReturn(true); + when(channelContext.getSelector()).thenReturn(socketSelector); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.isConnectComplete()).thenReturn(true); } public void testRegisterChannel() throws Exception { @@ -81,64 +81,52 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.preSelect(); - verify(eventHandler).handleRegistration(channel); + verify(eventHandler).handleRegistration(channelContext); } public void testClosedChannelWillNotBeRegistered() throws Exception { - when(channel.isOpen()).thenReturn(false); + when(channelContext.isOpen()).thenReturn(false); socketSelector.scheduleForRegistration(channel); socketSelector.preSelect(); - verify(eventHandler).registrationException(same(channel), any(ClosedChannelException.class)); - verify(channel, times(0)).finishConnect(); + verify(eventHandler).registrationException(same(channelContext), any(ClosedChannelException.class)); + verify(eventHandler, times(0)).handleConnect(channelContext); } public void testRegisterChannelFailsDueToException() throws Exception { socketSelector.scheduleForRegistration(channel); ClosedChannelException closedChannelException = new ClosedChannelException(); - doThrow(closedChannelException).when(channel).register(); + doThrow(closedChannelException).when(eventHandler).handleRegistration(channelContext); socketSelector.preSelect(); - verify(eventHandler).registrationException(channel, closedChannelException); - verify(channel, times(0)).finishConnect(); + verify(eventHandler).registrationException(channelContext, closedChannelException); + verify(eventHandler, times(0)).handleConnect(channelContext); } - public void testSuccessfullyRegisterChannelWillConnect() throws Exception { + public void testSuccessfullyRegisterChannelWillAttemptConnect() throws Exception { socketSelector.scheduleForRegistration(channel); - when(channel.finishConnect()).thenReturn(true); - socketSelector.preSelect(); - verify(eventHandler).handleConnect(channel); - } - - public void testConnectIncompleteWillNotNotify() throws Exception { - socketSelector.scheduleForRegistration(channel); - - when(channel.finishConnect()).thenReturn(false); - - socketSelector.preSelect(); - - verify(eventHandler, times(0)).handleConnect(channel); + verify(eventHandler).handleConnect(channelContext); } public void testQueueWriteWhenNotRunning() throws Exception { socketSelector.close(); - socketSelector.queueWrite(new BytesWriteOperation(channel, buffers, listener)); + socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); } public void testQueueWriteChannelIsClosed() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); - when(channel.isOpen()).thenReturn(false); + when(channelContext.isOpen()).thenReturn(false); socketSelector.preSelect(); verify(channelContext, times(0)).queueWriteOperation(writeOperation); @@ -148,11 +136,11 @@ public class SocketSelectorTests extends ESTestCase { public void testQueueWriteSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); socketSelector.queueWrite(writeOperation); - when(channel.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.preSelect(); @@ -161,7 +149,7 @@ public class SocketSelectorTests extends ESTestCase { } public void testQueueWriteSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -173,7 +161,7 @@ public class SocketSelectorTests extends ESTestCase { } public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -186,10 +174,10 @@ public class SocketSelectorTests extends ESTestCase { public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); - when(channel.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.queueWriteInChannelBuffer(writeOperation); @@ -200,19 +188,9 @@ public class SocketSelectorTests extends ESTestCase { public void testConnectEvent() throws Exception { selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - when(channel.finishConnect()).thenReturn(true); socketSelector.processKey(selectionKey); - verify(eventHandler).handleConnect(channel); - } - - public void testConnectEventFinishUnsuccessful() throws Exception { - selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - - when(channel.finishConnect()).thenReturn(false); - socketSelector.processKey(selectionKey); - - verify(eventHandler, times(0)).handleConnect(channel); + verify(eventHandler).handleConnect(channelContext); } public void testConnectEventFinishThrowException() throws Exception { @@ -220,11 +198,10 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - when(channel.finishConnect()).thenThrow(ioException); + doThrow(ioException).when(eventHandler).handleConnect(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler, times(0)).handleConnect(channel); - verify(eventHandler).connectException(channel, ioException); + verify(eventHandler).connectException(channelContext, ioException); } public void testWillNotConsiderWriteOrReadUntilConnectionComplete() throws Exception { @@ -232,13 +209,13 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ); - doThrow(ioException).when(eventHandler).handleWrite(channel); + doThrow(ioException).when(eventHandler).handleWrite(channelContext); - when(channel.isConnectComplete()).thenReturn(false); + when(channelContext.isConnectComplete()).thenReturn(false); socketSelector.processKey(selectionKey); - verify(eventHandler, times(0)).handleWrite(channel); - verify(eventHandler, times(0)).handleRead(channel); + verify(eventHandler, times(0)).handleWrite(channelContext); + verify(eventHandler, times(0)).handleRead(channelContext); } public void testSuccessfulWriteEvent() throws Exception { @@ -246,7 +223,7 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleWrite(channel); + verify(eventHandler).handleWrite(channelContext); } public void testWriteEventWithException() throws Exception { @@ -254,11 +231,11 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_WRITE); - doThrow(ioException).when(eventHandler).handleWrite(channel); + doThrow(ioException).when(eventHandler).handleWrite(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler).writeException(channel, ioException); + verify(eventHandler).writeException(channelContext, ioException); } public void testSuccessfulReadEvent() throws Exception { @@ -266,7 +243,7 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleRead(channel); + verify(eventHandler).handleRead(channelContext); } public void testReadEventWithException() throws Exception { @@ -274,11 +251,11 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_READ); - doThrow(ioException).when(eventHandler).handleRead(channel); + doThrow(ioException).when(eventHandler).handleRead(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler).readException(channel, ioException); + verify(eventHandler).readException(channelContext, ioException); } public void testWillCallPostHandleAfterChannelHandling() throws Exception { @@ -286,30 +263,32 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleWrite(channel); - verify(eventHandler).handleRead(channel); - verify(eventHandler).postHandling(channel); + verify(eventHandler).handleWrite(channelContext); + verify(eventHandler).handleRead(channelContext); + verify(eventHandler).postHandling(channelContext); } public void testCleanup() throws Exception { - NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); + NioSocketChannel unregisteredChannel = mock(NioSocketChannel.class); + SocketChannelContext unregisteredContext = mock(SocketChannelContext.class); + when(unregisteredChannel.getContext()).thenReturn(unregisteredContext); socketSelector.scheduleForRegistration(channel); socketSelector.preSelect(); - socketSelector.queueWrite(new BytesWriteOperation(mock(NioSocketChannel.class), buffers, listener)); - socketSelector.scheduleForRegistration(unRegisteredChannel); + socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); + socketSelector.scheduleForRegistration(unregisteredChannel); TestSelectionKey testSelectionKey = new TestSelectionKey(0); - testSelectionKey.attach(channel); + testSelectionKey.attach(channelContext); when(rawSelector.keys()).thenReturn(new HashSet<>(Collections.singletonList(testSelectionKey))); socketSelector.cleanupAndCloseChannels(); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); - verify(eventHandler).handleClose(channel); - verify(eventHandler).handleClose(unRegisteredChannel); + verify(eventHandler).handleClose(channelContext); + verify(eventHandler).handleClose(unregisteredContext); } public void testExecuteListenerWillHandleException() throws Exception { diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index acea1ca5d48..eb3d7f3d710 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -31,15 +31,15 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.indices.breaker.CircuitBreakerService; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptorEventHandler; +import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketEventHandler; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; @@ -53,7 +53,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentMap; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.elasticsearch.common.settings.Setting.intSetting; @@ -179,15 +179,15 @@ public class NioTransport extends TcpTransport { @Override public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { - TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector); + TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel); Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BiConsumer exceptionHandler = NioTransport.this::exceptionCaught; - BytesChannelContext context = new BytesChannelContext(nioChannel, exceptionHandler, nioReadConsumer, + Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; @@ -195,8 +195,9 @@ public class NioTransport extends TcpTransport { @Override public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); - ServerChannelContext context = new ServerChannelContext(nioChannel, NioTransport.this::acceptChannel, (c, e) -> {}); + TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, NioTransport.this::acceptChannel, + (e) -> {}); nioChannel.setContext(context); return nioChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java index 683ae146cfb..c63acc9f4de 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java @@ -38,10 +38,8 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements private final String profile; - public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, - ChannelFactory channelFactory, - AcceptingSelector selector) throws IOException { - super(socketChannel, channelFactory, selector); + public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel) throws IOException { + super(socketChannel); this.profile = profile; } @@ -62,7 +60,7 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements @Override public void close() { - getSelector().queueChannelClose(this); + getContext().closeChannel(); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java index c2064e53ca6..44ab17457e8 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java @@ -33,8 +33,8 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel private final String profile; - public TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + public TcpNioSocketChannel(String profile, SocketChannel socketChannel) throws IOException { + super(socketChannel); this.profile = profile; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index ec262261e54..5271ac6a148 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -164,8 +164,8 @@ public class MockNioTransport extends TcpTransport { }; SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BytesChannelContext context = new BytesChannelContext(nioChannel, MockNioTransport.this::exceptionCaught, nioReadConsumer, - new InboundChannelBuffer(pageSupplier)); + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), + nioReadConsumer, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; } @@ -173,7 +173,8 @@ public class MockNioTransport extends TcpTransport { @Override public MockServerChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel, this, selector); - ServerChannelContext context = new ServerChannelContext(nioServerChannel, MockNioTransport.this::acceptChannel, (c, e) -> {}); + ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, MockNioTransport.this::acceptChannel, + (e) -> {}); nioServerChannel.setContext(context); return nioServerChannel; } @@ -185,13 +186,13 @@ public class MockNioTransport extends TcpTransport { MockServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector) throws IOException { - super(channel, channelFactory, selector); + super(channel); this.profile = profile; } @Override public void close() { - getSelector().queueChannelClose(this); + getContext().closeChannel(); } @Override @@ -226,7 +227,7 @@ public class MockNioTransport extends TcpTransport { private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + super(socketChannel); this.profile = profile; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java index ecc00c24f9c..2e2d8aa5ada 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -20,7 +20,7 @@ package org.elasticsearch.transport.nio; import org.apache.logging.log4j.Logger; -import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketEventHandler; import java.io.IOException; @@ -34,36 +34,21 @@ public class TestingSocketEventHandler extends SocketEventHandler { super(logger); } - private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); + private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); - public void handleConnect(NioSocketChannel channel) { - assert hasConnectedMap.contains(channel) == false : "handleConnect should only be called once per channel"; - hasConnectedMap.add(channel); - super.handleConnect(channel); + public void handleConnect(SocketChannelContext context) throws IOException { + assert hasConnectedMap.contains(context) == false : "handleConnect should only be called is a channel is not yet connected"; + super.handleConnect(context); + if (context.isConnectComplete()) { + hasConnectedMap.add(context); + } } - private Set hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>()); + private Set hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>()); - public void connectException(NioSocketChannel channel, Exception e) { - assert hasConnectExceptionMap.contains(channel) == false : "connectException should only called at maximum once per channel"; - hasConnectExceptionMap.add(channel); - super.connectException(channel, e); + public void connectException(SocketChannelContext context, Exception e) { + assert hasConnectExceptionMap.contains(context) == false : "connectException should only called at maximum once per channel"; + hasConnectExceptionMap.add(context); + super.connectException(context, e); } - - public void handleRead(NioSocketChannel channel) throws IOException { - super.handleRead(channel); - } - - public void readException(NioSocketChannel channel, Exception e) { - super.readException(channel, e); - } - - public void handleWrite(NioSocketChannel channel) throws IOException { - super.handleWrite(channel); - } - - public void writeException(NioSocketChannel channel, Exception e) { - super.writeException(channel, e); - } - }