From c7a7c69b2bc1013a9b624052f24f63b37759b42d Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Fri, 21 Jul 2017 11:55:23 -0500 Subject: [PATCH] Simplify NioChannel creation and closing process (#25504) Currently an NioChannel is created and it is UNREGISTERED. At some point it is registered with a selector. From that point on, the channel can only be closed by the selector. The fact that a channel might not be associated with a selector has significant implications for concurrency and the channel shutdown process. The only thing that is simplified by allowing channels to be in a state independent of a selector is some testing scenarios. This PR modifies channels so that they are given a selector at creation time and are always associated with that selector. Only that selector can close that channel. This simplifies the channel lifecycle and closing intricacies. --- .../transport/nio/AcceptingSelector.java | 25 +-- .../transport/nio/AcceptorEventHandler.java | 15 +- .../transport/nio/ESSelector.java | 60 ++++--- .../transport/nio/EventHandler.java | 13 +- .../transport/nio/NioClient.java | 10 +- .../transport/nio/NioTransport.java | 46 +++--- .../transport/nio/OpenChannels.java | 29 ++-- .../transport/nio/SocketSelector.java | 23 +-- .../nio/channel/AbstractNioChannel.java | 106 +++++-------- .../transport/nio/channel/ChannelFactory.java | 133 +++++++++++----- .../transport/nio/channel/CloseFuture.java | 6 +- .../transport/nio/channel/NioChannel.java | 2 +- .../nio/channel/NioServerSocketChannel.java | 7 +- .../nio/channel/NioSocketChannel.java | 40 ++--- .../transport/nio/AcceptingSelectorTests.java | 43 ++++- .../nio/AcceptorEventHandlerTests.java | 40 +++-- .../transport/nio/ESSelectorTests.java | 6 +- .../transport/nio/NioClientTests.java | 26 +-- .../nio/SimpleNioTransportTests.java | 4 +- .../nio/SocketEventHandlerTests.java | 4 +- .../transport/nio/SocketSelectorTests.java | 34 ++-- .../channel/AbstractNioChannelTestCase.java | 101 ------------ .../nio/channel/ChannelFactoryTests.java | 150 ++++++++++++++++++ .../nio/channel/DoNotRegisterChannel.java | 15 +- .../channel/DoNotRegisterServerChannel.java | 15 +- .../channel/NioServerSocketChannelTests.java | 78 ++++++++- .../nio/channel/NioSocketChannelTests.java | 125 ++++++++++----- .../nio/channel/TcpWriteContextTests.java | 99 ------------ 28 files changed, 704 insertions(+), 551 deletions(-) delete mode 100644 test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java create mode 100644 test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java index ec5c9a963de..f43a0615005 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java @@ -68,16 +68,15 @@ public class AcceptingSelector extends ESSelector { @Override void cleanup() { - channelsToClose.addAll(registeredChannels); - closePendingChannels(); + channelsToClose.addAll(newChannels); } /** - * Registers a NioServerSocketChannel to be handled by this selector. The channel will by queued and + * Schedules a NioServerSocketChannel to be registered with this selector. The channel will by queued and * eventually registered next time through the event loop. * @param serverSocketChannel the channel to register */ - public void registerServerChannel(NioServerSocketChannel serverSocketChannel) { + public void scheduleForRegistration(NioServerSocketChannel serverSocketChannel) { newChannels.add(serverSocketChannel); ensureSelectorOpenForEnqueuing(newChannels, serverSocketChannel); wakeup(); @@ -86,11 +85,19 @@ public class AcceptingSelector extends ESSelector { private void setUpNewServerChannels() throws ClosedChannelException { NioServerSocketChannel newChannel; while ((newChannel = this.newChannels.poll()) != null) { - if (newChannel.register(this)) { - SelectionKey selectionKey = newChannel.getSelectionKey(); - selectionKey.attach(newChannel); - registeredChannels.add(newChannel); - eventHandler.serverChannelRegistered(newChannel); + assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; + try { + if (newChannel.isOpen()) { + newChannel.register(); + SelectionKey selectionKey = newChannel.getSelectionKey(); + selectionKey.attach(newChannel); + addRegisteredChannel(newChannel); + eventHandler.serverChannelRegistered(newChannel); + } else { + eventHandler.registrationException(newChannel, new ClosedChannelException()); + } + } catch (IOException e) { + eventHandler.registrationException(newChannel, e); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java index 7ce3b93e17c..7228cf4f050 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java @@ -53,6 +53,16 @@ public class AcceptorEventHandler extends EventHandler { openChannels.serverChannelOpened(nioServerSocketChannel); } + /** + * This method is called when an attempt to register a server channel throws an exception. + * + * @param channel that was registered + * @param exception that occurred + */ + public void registrationException(NioServerSocketChannel channel, Exception exception) { + logger.error("failed to register server channel", exception); + } + /** * This method is called when a server channel signals it is ready to accept a connection. All of the * accept logic should occur in this call. @@ -61,10 +71,9 @@ public class AcceptorEventHandler extends EventHandler { */ public void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException { ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); - NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel); + SocketSelector selector = selectorSupplier.get(); + NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector, openChannels::channelClosed); openChannels.acceptedChannelOpened(nioSocketChannel); - nioSocketChannel.getCloseFuture().setListener(openChannels::channelClosed); - selectorSupplier.get().registerSocketChannel(nioSocketChannel); } /** diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java index 44c3901d1ff..82cf6dafe03 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java @@ -46,12 +46,12 @@ public abstract class ESSelector implements Closeable { final Selector selector; final ConcurrentLinkedQueue channelsToClose = new ConcurrentLinkedQueue<>(); - final Set registeredChannels = Collections.newSetFromMap(new ConcurrentHashMap()); private final EventHandler eventHandler; private final ReentrantLock runLock = new ReentrantLock(); private final AtomicBoolean isClosed = new AtomicBoolean(false); private final PlainActionFuture isRunningFuture = PlainActionFuture.newFuture(); + private final Set registeredChannels = Collections.newSetFromMap(new ConcurrentHashMap()); private volatile Thread thread; ESSelector(EventHandler eventHandler) throws IOException { @@ -77,9 +77,15 @@ public abstract class ESSelector implements Closeable { } } finally { try { - cleanup(); + cleanupAndCloseChannels(); } finally { - runLock.unlock(); + try { + selector.close(); + } catch (IOException e) { + eventHandler.closeSelectorException(e); + } finally { + runLock.unlock(); + } } } } else { @@ -102,6 +108,12 @@ public abstract class ESSelector implements Closeable { } } + void cleanupAndCloseChannels() { + cleanup(); + channelsToClose.addAll(registeredChannels); + closePendingChannels(); + } + /** * Should implement the specific select logic. This will be called once per {@link #singleLoop()} * @@ -111,6 +123,11 @@ public abstract class ESSelector implements Closeable { */ abstract void doSelect(int timeout) throws IOException, ClosedSelectorException; + /** + * Called once as the selector is being closed. + */ + abstract void cleanup(); + void setThread() { thread = Thread.currentThread(); } @@ -119,8 +136,8 @@ public abstract class ESSelector implements Closeable { return Thread.currentThread() == thread; } - public void wakeup() { - // TODO: Do I need the wakeup optimizations that some other libraries use? + void wakeup() { + // TODO: Do we need the wakeup optimizations that some other libraries use? selector.wakeup(); } @@ -128,6 +145,15 @@ public abstract class ESSelector implements Closeable { return registeredChannels; } + public void addRegisteredChannel(NioChannel channel) { + assert registeredChannels.contains(channel) == false : "Should only register channel once"; + registeredChannels.add(channel); + } + + public void removeRegisteredChannel(NioChannel channel) { + registeredChannels.remove(channel); + } + @Override public void close() throws IOException { close(false); @@ -135,7 +161,6 @@ public abstract class ESSelector implements Closeable { public void close(boolean shouldInterrupt) throws IOException { if (isClosed.compareAndSet(false, true)) { - selector.close(); if (shouldInterrupt && thread != null) { thread.interrupt(); } else { @@ -146,24 +171,12 @@ public abstract class ESSelector implements Closeable { } public void queueChannelClose(NioChannel channel) { + assert channel.getSelector() == this : "Must schedule a channel for closure with its selector"; channelsToClose.offer(channel); ensureSelectorOpenForEnqueuing(channelsToClose, channel); wakeup(); } - void closePendingChannels() { - NioChannel channel; - while ((channel = channelsToClose.poll()) != null) { - closeChannel(channel); - } - } - - - /** - * Called once as the selector is being closed. - */ - abstract void cleanup(); - public Selector rawSelector() { return selector; } @@ -198,18 +211,17 @@ public abstract class ESSelector implements Closeable { * @param the object type */ void ensureSelectorOpenForEnqueuing(ConcurrentLinkedQueue queue, O objectAdded) { - if (isClosed.get() && isOnCurrentThread() == false) { + if (isOpen() == false && isOnCurrentThread() == false) { if (queue.remove(objectAdded)) { throw new IllegalStateException("selector is already closed"); } } } - private void closeChannel(NioChannel channel) { - try { + private void closePendingChannels() { + NioChannel channel; + while ((channel = channelsToClose.poll()) != null) { eventHandler.handleClose(channel); - } finally { - registeredChannels.remove(channel); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java index 6ecf36343f7..382a6728771 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java @@ -38,12 +38,21 @@ public abstract class EventHandler { /** * This method handles an IOException that was thrown during a call to {@link Selector#select(long)}. * - * @param exception that was uncaught + * @param exception the exception */ public void selectException(IOException exception) { logger.warn("io exception during select", exception); } + /** + * This method handles an IOException that was thrown during a call to {@link Selector#close()}. + * + * @param exception the exception + */ + public void closeSelectorException(IOException exception) { + logger.warn("io exception while closing selector", exception); + } + /** * This method handles an exception that was uncaught during a select loop. * @@ -65,7 +74,7 @@ public abstract class EventHandler { assert closeFuture.isDone() : "Should always be done as we are on the selector thread"; IOException closeException = closeFuture.getCloseException(); if (closeException != null) { - logger.trace("exception while closing channel", closeException); + logger.debug("exception while closing channel", closeException); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java index bc06ad0bc34..27ddca97878 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java @@ -35,15 +35,11 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.LockSupport; import java.util.function.Consumer; import java.util.function.Supplier; public class NioClient { - private static final int CLOSED = -1; - private final Logger logger; private final OpenChannels openChannels; private final Supplier selectorSupplier; @@ -72,12 +68,10 @@ public class NioClient { final InetSocketAddress address = node.getAddress().address(); try { for (int i = 0; i < channels.length; i++) { - SocketSelector socketSelector = selectorSupplier.get(); - NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address); + SocketSelector selector = selectorSupplier.get(); + NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector, closeListener); openChannels.clientChannelOpened(nioSocketChannel); - nioSocketChannel.getCloseFuture().setListener(closeListener); connections.add(nioSocketChannel); - socketSelector.registerSocketChannel(nioSocketChannel); } Exception ex = null; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 90ebc858bbf..ad95b0baeda 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -94,9 +94,8 @@ public class NioTransport extends TcpTransport { @Override protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { ChannelFactory channelFactory = this.profileToChannelFactory.get(name); - NioServerSocketChannel serverSocketChannel = channelFactory.openNioServerSocketChannel(name, address); - acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)).registerServerChannel(serverSocketChannel); - return serverSocketChannel; + AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)); + return channelFactory.openNioServerSocketChannel(name, address, selector); } @Override @@ -104,9 +103,14 @@ public class NioTransport extends TcpTransport { ArrayList futures = new ArrayList<>(channels.size()); for (final NioChannel channel : channels) { if (channel != null && channel.isOpen()) { + // We do not need to wait for the close operation to complete. If the close operation fails due + // to an IOException, the selector's handler will log the exception. Additionally, in the case + // of transport shutdown, where we do want to ensure that all channels to finished closing, the + // NioShutdown class will block on close. futures.add(channel.closeAsync()); } } + if (blocking == false) { return; } @@ -173,29 +177,31 @@ public class NioTransport extends TcpTransport { AcceptingSelector acceptor = new AcceptingSelector(eventHandler); acceptors.add(acceptor); } + + client = createClient(); + + for (SocketSelector selector : socketSelectors) { + if (selector.isRunning() == false) { + ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX); + threadFactory.newThread(selector::runLoop).start(); + selector.isRunningFuture().actionGet(); + } + } + + for (AcceptingSelector acceptor : acceptors) { + if (acceptor.isRunning() == false) { + ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX); + threadFactory.newThread(acceptor::runLoop).start(); + acceptor.isRunningFuture().actionGet(); + } + } + // loop through all profiles and start them up, special handling for default one for (ProfileSettings profileSettings : profileSettings) { profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, tcpReadHandler)); bindServer(profileSettings); } } - client = createClient(); - - for (SocketSelector selector : socketSelectors) { - if (selector.isRunning() == false) { - ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX); - threadFactory.newThread(selector::runLoop).start(); - selector.isRunningFuture().actionGet(); - } - } - - for (AcceptingSelector acceptor : acceptors) { - if (acceptor.isRunning() == false) { - ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX); - threadFactory.newThread(acceptor::runLoop).start(); - acceptor.isRunningFuture().actionGet(); - } - } super.doStart(); success = true; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java index eea353a6c14..4655f19001d 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java @@ -21,12 +21,14 @@ package org.elasticsearch.transport.nio; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.transport.nio.channel.CloseFuture; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import java.util.ArrayList; import java.util.HashSet; -import java.util.Map; +import java.util.List; import java.util.concurrent.ConcurrentMap; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; @@ -90,31 +92,40 @@ public class OpenChannels implements Releasable { } public void closeServerChannels() { + List futures = new ArrayList<>(); for (NioServerSocketChannel channel : openServerChannels.keySet()) { - ensureClosedInternal(channel); + CloseFuture closeFuture = channel.closeAsync(); + futures.add(closeFuture); } + ensureChannelsClosed(futures); openServerChannels.clear(); } @Override public void close() { + List futures = new ArrayList<>(); for (NioSocketChannel channel : openClientChannels.keySet()) { - ensureClosedInternal(channel); + CloseFuture closeFuture = channel.closeAsync(); + futures.add(closeFuture); } for (NioSocketChannel channel : openAcceptedChannels.keySet()) { - ensureClosedInternal(channel); + CloseFuture closeFuture = channel.closeAsync(); + futures.add(closeFuture); } + ensureChannelsClosed(futures); openClientChannels.clear(); openAcceptedChannels.clear(); } - private void ensureClosedInternal(NioChannel channel) { - try { - channel.closeAsync().get(); - } catch (Exception e) { - logger.trace("exception while closing channels", e); + private void ensureChannelsClosed(List futures) { + for (CloseFuture future : futures) { + try { + future.get(); + } catch (Exception e) { + logger.debug("exception while closing channels", e); + } } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java index ac40afe9bcc..b4da075f0fc 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java @@ -63,7 +63,6 @@ public class SocketSelector extends ESSelector { Set selectionKeys = selector.selectedKeys(); processKeys(selectionKeys); } - } @Override @@ -73,16 +72,14 @@ public class SocketSelector extends ESSelector { op.getListener().onFailure(new ClosedSelectorException()); } channelsToClose.addAll(newChannels); - channelsToClose.addAll(registeredChannels); - closePendingChannels(); } /** - * Registers a NioSocketChannel to be handled by this selector. The channel will by queued and eventually + * Schedules a NioSocketChannel to be registered by this selector. The channel will by queued and eventually * registered next time through the event loop. * @param nioSocketChannel the channel to register */ - public void registerSocketChannel(NioSocketChannel nioSocketChannel) { + public void scheduleForRegistration(NioSocketChannel nioSocketChannel) { newChannels.offer(nioSocketChannel); ensureSelectorOpenForEnqueuing(newChannels, nioSocketChannel); wakeup(); @@ -135,7 +132,7 @@ public class SocketSelector extends ESSelector { try { int ops = sk.readyOps(); if ((ops & SelectionKey.OP_CONNECT) != 0) { - attemptConnect(nioSocketChannel); + attemptConnect(nioSocketChannel, true); } if (nioSocketChannel.isConnectComplete()) { @@ -192,23 +189,29 @@ public class SocketSelector extends ESSelector { } private void setupChannel(NioSocketChannel newChannel) { + assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; try { - if (newChannel.register(this)) { - registeredChannels.add(newChannel); + if (newChannel.isOpen()) { + newChannel.register(); + addRegisteredChannel(newChannel); SelectionKey key = newChannel.getSelectionKey(); key.attach(newChannel); eventHandler.handleRegistration(newChannel); - attemptConnect(newChannel); + attemptConnect(newChannel, false); + } else { + eventHandler.registrationException(newChannel, new ClosedChannelException()); } } catch (Exception e) { eventHandler.registrationException(newChannel, e); } } - private void attemptConnect(NioSocketChannel newChannel) { + private void attemptConnect(NioSocketChannel newChannel, boolean connectEvent) { try { if (newChannel.finishConnect()) { eventHandler.handleConnect(newChannel); + } else if (connectEvent) { + eventHandler.connectException(newChannel, new IOException("Received OP_CONNECT but connect failed")); } } catch (Exception e) { eventHandler.connectException(newChannel, e); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java index 9792f9e64cc..c02312aab51 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java @@ -27,7 +27,7 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicBoolean; /** * This is a basic channel abstraction used by the {@link org.elasticsearch.transport.nio.NioTransport}. @@ -35,40 +35,34 @@ import java.util.concurrent.atomic.AtomicInteger; * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return * true until the channel is explicitly closed. *

- * A channel lifecycle has four stages: + * A channel lifecycle has two stages: *

    - *
  1. UNREGISTERED - When a channel is created and prior to it being registered with a selector. - *
  2. REGISTERED - When a channel has been registered with a selector. This is the state of a channel that - * can perform normal operations. - *
  3. CLOSING - When a channel has been marked for closed, but is not yet closed. {@link #isOpen()} will - * still return true. Normal operations should be rejected. The most common scenario for a channel to be - * CLOSING is when channel that was REGISTERED has {@link #closeAsync()} called, but the selector thread - * has not yet closed the channel. - *
  4. CLOSED - The channel has been closed. + *
  5. OPEN - When a channel has been created. This is the state of a channel that can perform normal operations. + *
  6. CLOSED - The channel has been set to closed. All this means is that the channel has been scheduled to be + * closed. The underlying raw channel may not yet be closed. The underlying channel has been closed if the close + * future has been completed. *
* * @param the type of raw channel this AbstractNioChannel uses */ public abstract class AbstractNioChannel implements NioChannel { - static final int UNREGISTERED = 0; - static final int REGISTERED = 1; - static final int CLOSING = 2; - static final int CLOSED = 3; - final S socketChannel; - final AtomicInteger state = new AtomicInteger(UNREGISTERED); + // This indicates if the channel has been scheduled to be closed. Read the closeFuture to determine if + // the channel close process has completed. + final AtomicBoolean isClosing = new AtomicBoolean(false); private final InetSocketAddress localAddress; private final String profile; private final CloseFuture closeFuture = new CloseFuture(); - private volatile ESSelector selector; + private final ESSelector selector; private SelectionKey selectionKey; - public AbstractNioChannel(String profile, S socketChannel) throws IOException { + AbstractNioChannel(String profile, S socketChannel, ESSelector selector) throws IOException { this.profile = profile; this.socketChannel = socketChannel; this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); + this.selector = selector; } @Override @@ -89,30 +83,17 @@ public abstract class AbstractNioChannel - * If the current state is UNREGISTERED, the call will attempt to transition the state from UNREGISTERED - * to CLOSING. If this transition is successful, the channel can no longer be registered with an event - * loop and the channel will be synchronously closed in this method call. - *

- * If the channel is REGISTERED and the state can be transitioned to CLOSING, the close operation will + * If the channel is open and the state can be transitioned to closed, the close operation will * be scheduled with the event loop. *

- * If the channel is CLOSING or CLOSED, nothing will be done. + * If the channel is already set to closed, it is assumed that it is already scheduled to be closed. * * @return future that will be complete when the channel is closed */ @Override public CloseFuture closeAsync() { - for (; ; ) { - int state = this.state.get(); - if (state == UNREGISTERED && this.state.compareAndSet(UNREGISTERED, CLOSING)) { - close0(); - break; - } else if (state == REGISTERED && this.state.compareAndSet(REGISTERED, CLOSING)) { - selector.queueChannelClose(this); - break; - } else if (state == CLOSING || state == CLOSED) { - break; - } + if (isClosing.compareAndSet(false, true)) { + selector.queueChannelClose(this); } return closeFuture; } @@ -124,37 +105,32 @@ public abstract class AbstractNioChannel closeListener) throws IOException { + SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); + NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel, selector); channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + channel.getCloseFuture().setListener(closeListener); + scheduleChannel(channel, selector); return channel; } - public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOException { - ServerSocketChannel serverSocketChannel = serverChannel.getRawChannel(); - SocketChannel rawChannel = PrivilegedSocketAccess.accept(serverSocketChannel); - configureSocketChannel(rawChannel); - NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel); + public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector, + Consumer closeListener) throws IOException { + SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); + NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel, selector); channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + channel.getCloseFuture().setListener(closeListener); + scheduleChannel(channel, selector); return channel; } - public NioServerSocketChannel openNioServerSocketChannel(String profileName, InetSocketAddress address) + public NioServerSocketChannel openNioServerSocketChannel(String profileName, InetSocketAddress address, AcceptingSelector selector) throws IOException { - ServerSocketChannel socketChannel = ServerSocketChannel.open(); - socketChannel.configureBlocking(false); - ServerSocket socket = socketChannel.socket(); - socket.setReuseAddress(tcpReusedAddress); - socketChannel.bind(address); - return new NioServerSocketChannel(profileName, socketChannel, this); + ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); + NioServerSocketChannel serverChannel = new NioServerSocketChannel(profileName, rawChannel, this, selector); + scheduleServerChannel(serverChannel, selector); + return serverChannel; } - private void configureSocketChannel(SocketChannel channel) throws IOException { - channel.configureBlocking(false); - Socket socket = channel.socket(); - socket.setTcpNoDelay(tcpNoDelay); - socket.setKeepAlive(tcpKeepAlive); - socket.setReuseAddress(tcpReusedAddress); - if (tcpSendBufferSize > 0) { - socket.setSendBufferSize(tcpSendBufferSize); + private void scheduleChannel(NioSocketChannel channel, SocketSelector selector) { + try { + selector.scheduleForRegistration(channel); + } catch (IllegalStateException e) { + IOUtils.closeWhileHandlingException(channel.getRawChannel()); + throw e; } - if (tcpReceiveBufferSize > 0) { - socket.setSendBufferSize(tcpReceiveBufferSize); + } + + private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSelector selector) { + try { + selector.scheduleForRegistration(channel); + } catch (IllegalStateException e) { + IOUtils.closeWhileHandlingException(channel.getRawChannel()); + throw e; + } + } + + static class RawChannelFactory { + + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean tcpReusedAddress; + private final int tcpSendBufferSize; + private final int tcpReceiveBufferSize; + + RawChannelFactory(TcpTransport.ProfileSettings profileSettings) { + tcpNoDelay = profileSettings.tcpNoDelay; + tcpKeepAlive = profileSettings.tcpKeepAlive; + tcpReusedAddress = profileSettings.reuseAddress; + tcpSendBufferSize = Math.toIntExact(profileSettings.sendBufferSize.getBytes()); + tcpReceiveBufferSize = Math.toIntExact(profileSettings.receiveBufferSize.getBytes()); + } + + SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException { + SocketChannel socketChannel = SocketChannel.open(); + configureSocketChannel(socketChannel); + PrivilegedSocketAccess.connect(socketChannel, remoteAddress); + return socketChannel; + } + + SocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOException { + ServerSocketChannel serverSocketChannel = serverChannel.getRawChannel(); + SocketChannel socketChannel = PrivilegedSocketAccess.accept(serverSocketChannel); + configureSocketChannel(socketChannel); + return socketChannel; + } + + ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws IOException { + ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + ServerSocket socket = serverSocketChannel.socket(); + socket.setReuseAddress(tcpReusedAddress); + serverSocketChannel.bind(address); + return serverSocketChannel; + } + + private void configureSocketChannel(SocketChannel channel) throws IOException { + channel.configureBlocking(false); + Socket socket = channel.socket(); + socket.setTcpNoDelay(tcpNoDelay); + socket.setKeepAlive(tcpKeepAlive); + socket.setReuseAddress(tcpReusedAddress); + if (tcpSendBufferSize > 0) { + socket.setSendBufferSize(tcpSendBufferSize); + } + if (tcpReceiveBufferSize > 0) { + socket.setSendBufferSize(tcpReceiveBufferSize); + } } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java index e41632174ac..c27ba306e0e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java @@ -80,7 +80,7 @@ public class CloseFuture extends BaseFuture { this.listener.set(listener); } - void channelClosed(NioChannel channel) { + boolean channelClosed(NioChannel channel) { boolean set = set(channel); if (set) { Consumer listener = this.listener.get(); @@ -88,10 +88,11 @@ public class CloseFuture extends BaseFuture { listener.accept(channel); } } + return set; } - void channelCloseThrewException(NioChannel channel, IOException ex) { + boolean channelCloseThrewException(NioChannel channel, IOException ex) { boolean set = setException(ex); if (set) { Consumer listener = this.listener.get(); @@ -99,6 +100,7 @@ public class CloseFuture extends BaseFuture { listener.accept(channel); } } + return set; } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java index 281e296391c..c4133cce271 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java @@ -40,7 +40,7 @@ public interface NioChannel { void closeFromSelector(); - boolean register(ESSelector selector) throws ClosedChannelException; + void register() throws ClosedChannelException; ESSelector getSelector(); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java index bc8d423a45d..a0524064cae 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java @@ -19,6 +19,8 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.transport.nio.AcceptingSelector; + import java.io.IOException; import java.nio.channels.ServerSocketChannel; @@ -26,8 +28,9 @@ public class NioServerSocketChannel extends AbstractNioChannel { private final InetSocketAddress remoteAddress; private final ConnectFuture connectFuture = new ConnectFuture(); - private volatile SocketSelector socketSelector; + private final SocketSelector socketSelector; private WriteContext writeContext; private ReadContext readContext; - public NioSocketChannel(String profile, SocketChannel socketChannel) throws IOException { - super(profile, socketChannel); + public NioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { + super(profile, socketChannel, selector); this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress(); - } - - @Override - public CloseFuture closeAsync() { - clearQueuedWrites(); - - return super.closeAsync(); + this.socketSelector = selector; } @Override public void closeFromSelector() { + assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; // Even if the channel has already been closed we will clear any pending write operations just in case - clearQueuedWrites(); + if (writeContext.hasQueuedWriteOps()) { + writeContext.clearQueuedWriteOps(new ClosedChannelException()); + } super.closeFromSelector(); } @@ -63,12 +59,6 @@ public class NioSocketChannel extends AbstractNioChannel { return socketSelector; } - @Override - boolean markRegistered(ESSelector selector) { - this.socketSelector = (SocketSelector) selector; - return super.markRegistered(selector); - } - public int write(NetworkBytesReference[] references) throws IOException { int written; if (references.length == 1) { @@ -122,11 +112,11 @@ public class NioSocketChannel extends AbstractNioChannel { } public boolean isWritable() { - return state.get() == REGISTERED; + return isClosing.get() == false; } public boolean isReadable() { - return state.get() == REGISTERED; + return isClosing.get() == false; } /** @@ -176,14 +166,4 @@ public class NioSocketChannel extends AbstractNioChannel { throw e; } } - - private void clearQueuedWrites() { - // Even if the channel has already been closed we will clear any pending write operations just in case - if (state.get() > UNREGISTERED) { - SocketSelector selector = getSelector(); - if (selector != null && selector.isOnCurrentThread() && writeContext.hasQueuedWriteOps()) { - writeContext.clearQueuedWriteOps(new ClosedChannelException()); - } - } - } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java index e3cf9b0a7e9..05d3b292b0a 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java @@ -26,12 +26,15 @@ import org.elasticsearch.transport.nio.utils.TestSelectionKey; import org.junit.Before; import java.io.IOException; +import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.security.PrivilegedActionException; import java.util.HashSet; import java.util.Set; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -59,14 +62,14 @@ public class AcceptingSelectorTests extends ESTestCase { selectionKey = new TestSelectionKey(0); selectionKey.attach(serverChannel); when(serverChannel.getSelectionKey()).thenReturn(selectionKey); + when(serverChannel.getSelector()).thenReturn(selector); + when(serverChannel.isOpen()).thenReturn(true); when(rawSelector.selectedKeys()).thenReturn(keySet); when(rawSelector.select(0)).thenReturn(1); } public void testRegisteredChannel() throws IOException, PrivilegedActionException { - selector.registerServerChannel(serverChannel); - - when(serverChannel.register(selector)).thenReturn(true); + selector.scheduleForRegistration(serverChannel); selector.doSelect(0); @@ -76,6 +79,34 @@ public class AcceptingSelectorTests extends ESTestCase { assertTrue(registeredChannels.contains(serverChannel)); } + public void testClosedChannelWillNotBeRegistered() throws Exception { + when(serverChannel.isOpen()).thenReturn(false); + selector.scheduleForRegistration(serverChannel); + + selector.doSelect(0); + + verify(eventHandler).registrationException(same(serverChannel), any(ClosedChannelException.class)); + + Set registeredChannels = selector.getRegisteredChannels(); + assertEquals(0, registeredChannels.size()); + assertFalse(registeredChannels.contains(serverChannel)); + } + + public void testRegisterChannelFailsDueToException() throws Exception { + selector.scheduleForRegistration(serverChannel); + + ClosedChannelException closedChannelException = new ClosedChannelException(); + doThrow(closedChannelException).when(serverChannel).register(); + + selector.doSelect(0); + + verify(eventHandler).registrationException(serverChannel, closedChannelException); + + Set registeredChannels = selector.getRegisteredChannels(); + assertEquals(0, registeredChannels.size()); + assertFalse(registeredChannels.contains(serverChannel)); + } + public void testAcceptEvent() throws IOException { selectionKey.setReadyOps(SelectionKey.OP_ACCEPT); keySet.add(selectionKey); @@ -98,15 +129,13 @@ public class AcceptingSelectorTests extends ESTestCase { } public void testCleanup() throws IOException { - selector.registerServerChannel(serverChannel); - - when(serverChannel.register(selector)).thenReturn(true); + selector.scheduleForRegistration(serverChannel); selector.doSelect(0); assertEquals(1, selector.getRegisteredChannels().size()); - selector.cleanup(); + selector.cleanupAndCloseChannels(); verify(eventHandler).handleClose(serverChannel); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java index fc6829d5948..8ae6559c741 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java @@ -19,21 +19,29 @@ package org.elasticsearch.transport.nio; +import org.apache.lucene.util.IOUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel; +import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.ReadContext; +import org.elasticsearch.transport.nio.channel.WriteContext; import org.junit.Before; +import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.function.Consumer; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -55,8 +63,9 @@ public class AcceptorEventHandlerTests extends ESTestCase { selectors.add(socketSelector); handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors)); - channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory); - channel.register(mock(ESSelector.class)); + AcceptingSelector selector = mock(AcceptingSelector.class); + channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory, selector); + channel.register(); } public void testHandleRegisterAdjustsOpenChannels() { @@ -75,25 +84,34 @@ public class AcceptorEventHandlerTests extends ESTestCase { assertEquals(SelectionKey.OP_ACCEPT, channel.getSelectionKey().interestOps()); } - public void testHandleAcceptRegistersWithSelector() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class)); - when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel); + public void testHandleAcceptCallsChannelFactory() throws IOException { + NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class), socketSelector); + when(channelFactory.acceptNioChannel(same(channel), same(socketSelector), any())).thenReturn(childChannel); handler.acceptChannel(channel); - verify(socketSelector).registerSocketChannel(childChannel); + verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector), any()); + } + @SuppressWarnings("unchecked") public void testHandleAcceptAddsToOpenChannelsAndAddsCloseListenerToRemove() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel("", SocketChannel.open()); - when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel); + SocketChannel rawChannel = SocketChannel.open(); + NioSocketChannel childChannel = new NioSocketChannel("", rawChannel, socketSelector); + childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + when(channelFactory.acceptNioChannel(same(channel), same(socketSelector), any())).thenReturn(childChannel); handler.acceptChannel(channel); + Class> clazz = (Class>)(Class)Consumer.class; + ArgumentCaptor> listener = ArgumentCaptor.forClass(clazz); + verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector), listener.capture()); - assertEquals(new HashSet<>(Arrays.asList(childChannel)), openChannels.getAcceptedChannels()); + assertEquals(new HashSet<>(Collections.singletonList(childChannel)), openChannels.getAcceptedChannels()); - childChannel.closeAsync(); + listener.getValue().accept(childChannel); assertEquals(new HashSet<>(), openChannels.getAcceptedChannels()); + + IOUtils.closeWhileHandlingException(rawChannel); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java index e57b1bc4efd..53705fcf521 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java @@ -28,6 +28,7 @@ import java.nio.channels.ClosedSelectorException; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class ESSelectorTests extends ESTestCase { @@ -43,7 +44,8 @@ public class ESSelectorTests extends ESTestCase { public void testQueueChannelForClosed() throws IOException { NioChannel channel = mock(NioChannel.class); - selector.registeredChannels.add(channel); + when(channel.getSelector()).thenReturn(selector); + selector.addRegisteredChannel(channel); selector.queueChannelClose(channel); @@ -52,6 +54,8 @@ public class ESSelectorTests extends ESTestCase { selector.singleLoop(); verify(handler).handleClose(channel); + // Will be called in the channel close method + selector.removeRegisteredChannel(channel); assertEquals(0, selector.getRegisteredChannels().size()); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java index e9f6dfe7f71..4cae51acc83 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java @@ -75,26 +75,17 @@ public class NioClientTests extends ESTestCase { public void testCreateConnections() throws IOException, InterruptedException { NioSocketChannel channel1 = mock(NioSocketChannel.class); ConnectFuture connectFuture1 = mock(ConnectFuture.class); - CloseFuture closeFuture1 = mock(CloseFuture.class); NioSocketChannel channel2 = mock(NioSocketChannel.class); ConnectFuture connectFuture2 = mock(ConnectFuture.class); - CloseFuture closeFuture2 = mock(CloseFuture.class); - when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); - when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(channel2.getCloseFuture()).thenReturn(closeFuture2); when(channel2.getConnectFuture()).thenReturn(connectFuture2); when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); - verify(closeFuture1).setListener(listener); - verify(closeFuture2).setListener(listener); - verify(selector).registerSocketChannel(channel1); - verify(selector).registerSocketChannel(channel2); - assertEquals(channel1, channels[0]); assertEquals(channel2, channels[1]); } @@ -102,19 +93,14 @@ public class NioClientTests extends ESTestCase { public void testWithADifferentConnectTimeout() throws IOException, InterruptedException { NioSocketChannel channel1 = mock(NioSocketChannel.class); ConnectFuture connectFuture1 = mock(ConnectFuture.class); - CloseFuture closeFuture1 = mock(CloseFuture.class); - when(channelFactory.openNioChannel(address.address())).thenReturn(channel1); - when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1); when(channel1.getConnectFuture()).thenReturn(connectFuture1); when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true); channels = new NioSocketChannel[1]; client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener); - verify(closeFuture1).setListener(listener); - verify(selector).registerSocketChannel(channel1); - assertEquals(channel1, channels[0]); } @@ -126,7 +112,7 @@ public class NioClientTests extends ESTestCase { ConnectFuture connectFuture2 = mock(ConnectFuture.class); CloseFuture closeFuture2 = mock(CloseFuture.class); - when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); + when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); when(channel1.getCloseFuture()).thenReturn(closeFuture1); when(channel1.getConnectFuture()).thenReturn(connectFuture1); when(channel2.getCloseFuture()).thenReturn(closeFuture2); @@ -151,16 +137,12 @@ public class NioClientTests extends ESTestCase { public void testConnectionException() throws IOException, InterruptedException { NioSocketChannel channel1 = mock(NioSocketChannel.class); ConnectFuture connectFuture1 = mock(ConnectFuture.class); - CloseFuture closeFuture1 = mock(CloseFuture.class); NioSocketChannel channel2 = mock(NioSocketChannel.class); ConnectFuture connectFuture2 = mock(ConnectFuture.class); - CloseFuture closeFuture2 = mock(CloseFuture.class); IOException ioException = new IOException(); - when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); - when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(channel2.getCloseFuture()).thenReturn(closeFuture2); when(channel2.getConnectFuture()).thenReturn(connectFuture2); when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index 8e16a040b74..2ba2e4cc02a 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -88,7 +88,9 @@ public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { @Override protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { - settings = Settings.builder().put(settings).put(TcpTransport.PORT.getKey(), "0").build(); + settings = Settings.builder().put(settings) + .put(TcpTransport.PORT.getKey(), "0") + .build(); MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); transportService.start(); return transportService; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java index 393b9dc7cc5..3bc5cd083a6 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java @@ -58,12 +58,12 @@ public class SocketEventHandlerTests extends ESTestCase { SocketSelector socketSelector = mock(SocketSelector.class); handler = new SocketEventHandler(logger, exceptionHandler); rawChannel = mock(SocketChannel.class); - channel = new DoNotRegisterChannel("", rawChannel); + channel = new DoNotRegisterChannel("", rawChannel, socketSelector); readContext = mock(ReadContext.class); when(rawChannel.finishConnect()).thenReturn(true); channel.setContexts(readContext, new TcpWriteContext(channel)); - channel.register(socketSelector); + channel.register(); channel.finishConnect(); when(socketSelector.isOnCurrentThread()).thenReturn(true); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java index 050cf856442..50ce4a55b29 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java @@ -39,6 +39,7 @@ import java.util.Set; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -73,15 +74,15 @@ public class SocketSelectorTests extends ESTestCase { when(rawSelector.selectedKeys()).thenReturn(keySet); when(rawSelector.select(0)).thenReturn(1); + when(channel.isOpen()).thenReturn(true); when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getWriteContext()).thenReturn(writeContext); when(channel.isConnectComplete()).thenReturn(true); + when(channel.getSelector()).thenReturn(socketSelector); } public void testRegisterChannel() throws Exception { - socketSelector.registerSocketChannel(channel); - - when(channel.register(socketSelector)).thenReturn(true); + socketSelector.scheduleForRegistration(channel); socketSelector.doSelect(0); @@ -92,13 +93,13 @@ public class SocketSelectorTests extends ESTestCase { assertTrue(registeredChannels.contains(channel)); } - public void testRegisterChannelFails() throws Exception { - socketSelector.registerSocketChannel(channel); - - when(channel.register(socketSelector)).thenReturn(false); + public void testClosedChannelWillNotBeRegistered() throws Exception { + when(channel.isOpen()).thenReturn(false); + socketSelector.scheduleForRegistration(channel); socketSelector.doSelect(0); + verify(eventHandler).registrationException(same(channel), any(ClosedChannelException.class)); verify(channel, times(0)).finishConnect(); Set registeredChannels = socketSelector.getRegisteredChannels(); @@ -107,10 +108,10 @@ public class SocketSelectorTests extends ESTestCase { } public void testRegisterChannelFailsDueToException() throws Exception { - socketSelector.registerSocketChannel(channel); + socketSelector.scheduleForRegistration(channel); ClosedChannelException closedChannelException = new ClosedChannelException(); - when(channel.register(socketSelector)).thenThrow(closedChannelException); + doThrow(closedChannelException).when(channel).register(); socketSelector.doSelect(0); @@ -123,9 +124,8 @@ public class SocketSelectorTests extends ESTestCase { } public void testSuccessfullyRegisterChannelWillConnect() throws Exception { - socketSelector.registerSocketChannel(channel); + socketSelector.scheduleForRegistration(channel); - when(channel.register(socketSelector)).thenReturn(true); when(channel.finishConnect()).thenReturn(true); socketSelector.doSelect(0); @@ -134,9 +134,8 @@ public class SocketSelectorTests extends ESTestCase { } public void testConnectIncompleteWillNotNotify() throws Exception { - socketSelector.registerSocketChannel(channel); + socketSelector.scheduleForRegistration(channel); - when(channel.register(socketSelector)).thenReturn(true); when(channel.finishConnect()).thenReturn(false); socketSelector.doSelect(0); @@ -145,7 +144,7 @@ public class SocketSelectorTests extends ESTestCase { } public void testQueueWriteWhenNotRunning() throws Exception { - socketSelector.close(false); + socketSelector.close(); socketSelector.queueWrite(new WriteOperation(channel, bufferReference, listener)); @@ -318,16 +317,15 @@ public class SocketSelectorTests extends ESTestCase { public void testCleanup() throws Exception { NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); - when(channel.register(socketSelector)).thenReturn(true); - socketSelector.registerSocketChannel(channel); + socketSelector.scheduleForRegistration(channel); socketSelector.doSelect(0); NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(new BytesArray(new byte[1])); socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), networkBuffer, listener)); - socketSelector.registerSocketChannel(unRegisteredChannel); + socketSelector.scheduleForRegistration(unRegisteredChannel); - socketSelector.cleanup(); + socketSelector.cleanupAndCloseChannels(); verify(listener).onFailure(any(ClosedSelectorException.class)); verify(eventHandler).handleClose(channel); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java deleted file mode 100644 index 7db9f48ca45..00000000000 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.transport.nio.channel; - -import org.elasticsearch.common.CheckedRunnable; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.mocksocket.MockServerSocket; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.TcpTransport; -import org.elasticsearch.transport.nio.TcpReadHandler; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.io.InputStream; -import java.net.Socket; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicReference; - -import static org.mockito.Mockito.mock; - -public abstract class AbstractNioChannelTestCase extends ESTestCase { - - ChannelFactory channelFactory = new ChannelFactory(new TcpTransport.ProfileSettings(Settings.EMPTY, "default"), - mock(TcpReadHandler.class)); - MockServerSocket mockServerSocket; - private Thread serverThread; - - @Before - public void serverSocketSetup() throws IOException { - mockServerSocket = new MockServerSocket(0); - serverThread = new Thread(() -> { - while (!mockServerSocket.isClosed()) { - try { - Socket socket = mockServerSocket.accept(); - InputStream inputStream = socket.getInputStream(); - socket.close(); - } catch (IOException e) { - } - } - }); - serverThread.start(); - } - - @After - public void serverSocketTearDown() throws IOException { - serverThread.interrupt(); - mockServerSocket.close(); - } - - public abstract NioChannel channelToClose() throws IOException; - - public void testClose() throws IOException, TimeoutException, InterruptedException { - AtomicReference ref = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); - - NioChannel socketChannel = channelToClose(); - CloseFuture closeFuture = socketChannel.getCloseFuture(); - closeFuture.setListener((c) -> {ref.set(c); latch.countDown();}); - - assertFalse(closeFuture.isClosed()); - assertTrue(socketChannel.getRawChannel().isOpen()); - - socketChannel.closeAsync(); - - closeFuture.awaitClose(100, TimeUnit.SECONDS); - - assertFalse(socketChannel.getRawChannel().isOpen()); - assertTrue(closeFuture.isClosed()); - latch.await(); - assertSame(socketChannel, ref.get()); - } - - protected Runnable wrappedRunnable(CheckedRunnable runnable) { - return () -> { - try { - runnable.run(); - } catch (Exception e) { - } - }; - } -} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java new file mode 100644 index 00000000000..8851c37f201 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java @@ -0,0 +1,150 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.apache.lucene.util.IOUtils; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.AcceptingSelector; +import org.elasticsearch.transport.nio.SocketSelector; +import org.elasticsearch.transport.nio.TcpReadHandler; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.function.Consumer; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ChannelFactoryTests extends ESTestCase { + + private ChannelFactory channelFactory; + private ChannelFactory.RawChannelFactory rawChannelFactory; + private Consumer listener; + private SocketChannel rawChannel; + private ServerSocketChannel rawServerChannel; + private SocketSelector socketSelector; + private AcceptingSelector acceptingSelector; + + @Before + @SuppressWarnings("unchecked") + public void setupFactory() throws IOException { + rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class); + channelFactory = new ChannelFactory(rawChannelFactory, mock(TcpReadHandler.class)); + listener = mock(Consumer.class); + socketSelector = mock(SocketSelector.class); + acceptingSelector = mock(AcceptingSelector.class); + rawChannel = SocketChannel.open(); + rawServerChannel = ServerSocketChannel.open(); + } + + @After + public void ensureClosed() throws IOException { + IOUtils.closeWhileHandlingException(rawChannel); + IOUtils.closeWhileHandlingException(rawServerChannel); + } + + public void testAcceptChannel() throws IOException { + NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); + when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + when(serverChannel.getProfile()).thenReturn("parent-profile"); + + NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector, listener); + + verify(socketSelector).scheduleForRegistration(channel); + + assertEquals(socketSelector, channel.getSelector()); + assertEquals("parent-profile", channel.getProfile()); + assertEquals(rawChannel, channel.getRawChannel()); + + channel.getCloseFuture().channelClosed(channel); + + verify(listener).accept(channel); + } + + public void testAcceptedChannelRejected() throws IOException { + NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); + when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); + + expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannel, socketSelector, listener)); + + assertFalse(rawChannel.isOpen()); + } + + public void testOpenChannel() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); + + NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelector, listener); + + verify(socketSelector).scheduleForRegistration(channel); + + assertEquals(socketSelector, channel.getSelector()); + assertEquals("client-socket", channel.getProfile()); + assertEquals(rawChannel, channel.getRawChannel()); + + channel.getCloseFuture().channelClosed(channel); + + verify(listener).accept(channel); + } + + public void testOpenedChannelRejected() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); + doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); + + expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelector, listener)); + + assertFalse(rawChannel.isOpen()); + } + + public void testOpenServerChannel() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); + + String profile = "profile"; + NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(profile, address, acceptingSelector); + + verify(acceptingSelector).scheduleForRegistration(channel); + + assertEquals(acceptingSelector, channel.getSelector()); + assertEquals(profile, channel.getProfile()); + assertEquals(rawServerChannel, channel.getRawChannel()); + } + + public void testOpenedServerChannelRejected() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); + doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any()); + + expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel("", address, acceptingSelector)); + + assertFalse(rawServerChannel.isOpen()); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java index 38f381bfcc5..70496da8a49 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java @@ -19,7 +19,7 @@ package org.elasticsearch.transport.nio.channel; -import org.elasticsearch.transport.nio.ESSelector; +import org.elasticsearch.transport.nio.SocketSelector; import org.elasticsearch.transport.nio.utils.TestSelectionKey; import java.io.IOException; @@ -28,17 +28,12 @@ import java.nio.channels.SocketChannel; public class DoNotRegisterChannel extends NioSocketChannel { - public DoNotRegisterChannel(String profile, SocketChannel socketChannel) throws IOException { - super(profile, socketChannel); + public DoNotRegisterChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { + super(profile, socketChannel, selector); } @Override - public boolean register(ESSelector selector) throws ClosedChannelException { - if (markRegistered(selector)) { - setSelectionKey(new TestSelectionKey(0)); - return true; - } else { - return false; - } + public void register() throws ClosedChannelException { + setSelectionKey(new TestSelectionKey(0)); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java index e9e1fc207a0..783bd6fc5fa 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.ESSelector; import org.elasticsearch.transport.nio.utils.TestSelectionKey; @@ -28,17 +29,13 @@ import java.nio.channels.ServerSocketChannel; public class DoNotRegisterServerChannel extends NioServerSocketChannel { - public DoNotRegisterServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory) throws IOException { - super(profile, channel, channelFactory); + public DoNotRegisterServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory, + AcceptingSelector selector) throws IOException { + super(profile, channel, channelFactory, selector); } @Override - public boolean register(ESSelector selector) throws ClosedChannelException { - if (markRegistered(selector)) { - setSelectionKey(new TestSelectionKey(0)); - return true; - } else { - return false; - } + public void register() throws ClosedChannelException { + setSelectionKey(new TestSelectionKey(0)); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java index c991263562c..6f05d3c1f34 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java @@ -19,15 +19,81 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.AcceptingSelector; +import org.elasticsearch.transport.nio.AcceptorEventHandler; +import org.elasticsearch.transport.nio.OpenChannels; +import org.junit.After; +import org.junit.Before; + import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; -public class NioServerSocketChannelTests extends AbstractNioChannelTestCase { +import static org.mockito.Mockito.mock; - @Override - public NioChannel channelToClose() throws IOException { - return channelFactory.openNioServerSocketChannel("nio", new InetSocketAddress(InetAddress.getLoopbackAddress(),0)); +public class NioServerSocketChannelTests extends ESTestCase { + + private AcceptingSelector selector; + private AtomicBoolean closedRawChannel; + private Thread thread; + + @Before + @SuppressWarnings("unchecked") + public void setSelector() throws IOException { + selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(OpenChannels.class), mock(Supplier.class))); + thread = new Thread(selector::runLoop); + closedRawChannel = new AtomicBoolean(false); + thread.start(); + selector.isRunningFuture().actionGet(); } + @After + public void stopSelector() throws IOException, InterruptedException { + selector.close(); + thread.join(); + } + + public void testClose() throws IOException, TimeoutException, InterruptedException { + AtomicReference ref = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + NioChannel channel = new DoNotCloseServerChannel("nio", mock(ServerSocketChannel.class), mock(ChannelFactory.class), selector); + channel.getCloseFuture().setListener((c) -> { + ref.set(c); + latch.countDown(); + }); + + CloseFuture closeFuture = channel.getCloseFuture(); + + assertFalse(closeFuture.isClosed()); + assertFalse(closedRawChannel.get()); + + channel.closeAsync(); + + closeFuture.awaitClose(100, TimeUnit.SECONDS); + + assertTrue(closedRawChannel.get()); + assertTrue(closeFuture.isClosed()); + latch.await(); + assertSame(channel, ref.get()); + } + + private class DoNotCloseServerChannel extends DoNotRegisterServerChannel { + + private DoNotCloseServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory, + AcceptingSelector selector) throws IOException { + super(profile, channel, channelFactory, selector); + } + + @Override + void closeRawChannel() throws IOException { + closedRawChannel.set(true); + } + } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java index 95b858dd5bc..3d039b41a8a 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java @@ -19,65 +19,116 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.SocketEventHandler; +import org.elasticsearch.transport.nio.SocketSelector; +import org.junit.After; +import org.junit.Before; + import java.io.IOException; import java.net.ConnectException; -import java.net.InetAddress; -import java.net.InetSocketAddress; +import java.nio.channels.SocketChannel; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.LockSupport; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; -public class NioSocketChannelTests extends AbstractNioChannelTestCase { +public class NioSocketChannelTests extends ESTestCase { - private InetAddress loopbackAddress = InetAddress.getLoopbackAddress(); + private SocketSelector selector; + private AtomicBoolean closedRawChannel; + private Thread thread; - @Override - public NioChannel channelToClose() throws IOException { - return channelFactory.openNioChannel(new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort())); + @Before + @SuppressWarnings("unchecked") + public void startSelector() throws IOException { + selector = new SocketSelector(new SocketEventHandler(logger, mock(BiConsumer.class))); + thread = new Thread(selector::runLoop); + closedRawChannel = new AtomicBoolean(false); + thread.start(); + selector.isRunningFuture().actionGet(); } - public void testConnectSucceeds() throws IOException, InterruptedException { - InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort()); - NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress); - Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel))); - thread.start(); - ConnectFuture connectFuture = socketChannel.getConnectFuture(); - connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS); - - assertTrue(socketChannel.isConnectComplete()); - assertTrue(socketChannel.isOpen()); - assertFalse(connectFuture.connectFailed()); - assertNull(connectFuture.getException()); - + @After + public void stopSelector() throws IOException, InterruptedException { + selector.close(); thread.join(); } - public void testConnectFails() throws IOException, InterruptedException { - mockServerSocket.close(); - InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort()); - NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress); - Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel))); - thread.start(); + public void testClose() throws IOException, TimeoutException, InterruptedException { + AtomicReference ref = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, mock(SocketChannel.class), selector); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.getCloseFuture().setListener((c) -> { + ref.set(c); + latch.countDown(); + }); + CloseFuture closeFuture = socketChannel.getCloseFuture(); + + assertFalse(closeFuture.isClosed()); + assertFalse(closedRawChannel.get()); + + socketChannel.closeAsync(); + + closeFuture.awaitClose(100, TimeUnit.SECONDS); + + assertTrue(closedRawChannel.get()); + assertTrue(closeFuture.isClosed()); + latch.await(); + assertSame(socketChannel, ref.get()); + } + + public void testConnectSucceeds() throws IOException, InterruptedException { + SocketChannel rawChannel = mock(SocketChannel.class); + when(rawChannel.finishConnect()).thenReturn(true); + NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, rawChannel, selector); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + selector.scheduleForRegistration(socketChannel); + ConnectFuture connectFuture = socketChannel.getConnectFuture(); - connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS); + assertTrue(connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS)); + + assertTrue(socketChannel.isConnectComplete()); + assertTrue(socketChannel.isOpen()); + assertFalse(closedRawChannel.get()); + assertFalse(connectFuture.connectFailed()); + assertNull(connectFuture.getException()); + } + + public void testConnectFails() throws IOException, InterruptedException { + SocketChannel rawChannel = mock(SocketChannel.class); + when(rawChannel.finishConnect()).thenThrow(new ConnectException()); + NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, rawChannel, selector); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + selector.scheduleForRegistration(socketChannel); + + ConnectFuture connectFuture = socketChannel.getConnectFuture(); + assertFalse(connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS)); assertFalse(socketChannel.isConnectComplete()); // Even if connection fails the channel is 'open' until close() is called assertTrue(socketChannel.isOpen()); assertTrue(connectFuture.connectFailed()); assertThat(connectFuture.getException(), instanceOf(ConnectException.class)); - - thread.join(); } - private void ensureConnect(NioSocketChannel nioSocketChannel) throws IOException { - for (;;) { - boolean isConnected = nioSocketChannel.finishConnect(); - if (isConnected) { - return; - } - LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(1)); + private class DoNotCloseChannel extends DoNotRegisterChannel { + + private DoNotCloseChannel(String profile, SocketChannel channel, SocketSelector selector) throws IOException { + super(profile, channel, selector); + } + + @Override + void closeRawChannel() throws IOException { + closedRawChannel.set(true); } } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java index d2a2f446e73..171903a1b79 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java @@ -186,104 +186,6 @@ public class TcpWriteContextTests extends ESTestCase { assertFalse(writeContext.hasQueuedWriteOps()); } - private class ConsumeAllChannel extends NioSocketChannel { - - private byte[] bytes; - private byte[] bytes2; - - ConsumeAllChannel() throws IOException { - super("", mock(SocketChannel.class)); - } - - public int write(ByteBuffer buffer) throws IOException { - bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - return bytes.length; - } - - public long vectorizedWrite(ByteBuffer[] buffer) throws IOException { - if (buffer.length != 2) { - throw new IOException("Only allows 2 buffers"); - } - bytes = new byte[buffer[0].remaining()]; - buffer[0].get(bytes); - - bytes2 = new byte[buffer[1].remaining()]; - buffer[1].get(bytes2); - return bytes.length + bytes2.length; - } - } - - private class HalfConsumeChannel extends NioSocketChannel { - - private byte[] bytes; - private byte[] bytes2; - - HalfConsumeChannel() throws IOException { - super("", mock(SocketChannel.class)); - } - - public int write(ByteBuffer buffer) throws IOException { - bytes = new byte[buffer.limit() / 2]; - buffer.get(bytes); - return bytes.length; - } - - public long vectorizedWrite(ByteBuffer[] buffers) throws IOException { - if (buffers.length != 2) { - throw new IOException("Only allows 2 buffers"); - } - if (bytes == null) { - bytes = new byte[buffers[0].remaining()]; - bytes2 = new byte[buffers[1].remaining()]; - } - - if (buffers[0].remaining() != 0) { - buffers[0].get(bytes); - return bytes.length; - } else { - buffers[1].get(bytes2); - return bytes2.length; - } - } - } - - private class MultiWriteChannel extends NioSocketChannel { - - private byte[] write1Bytes; - private byte[] write1Bytes2; - private byte[] write2Bytes1; - private byte[] write2Bytes2; - - MultiWriteChannel() throws IOException { - super("", mock(SocketChannel.class)); - } - - public long vectorizedWrite(ByteBuffer[] buffers) throws IOException { - if (buffers.length != 4 && write1Bytes == null) { - throw new IOException("Only allows 4 buffers"); - } else if (buffers.length != 2 && write1Bytes != null) { - throw new IOException("Only allows 2 buffers on second write"); - } - if (write1Bytes == null) { - write1Bytes = new byte[buffers[0].remaining()]; - write1Bytes2 = new byte[buffers[1].remaining()]; - write2Bytes1 = new byte[buffers[2].remaining()]; - write2Bytes2 = new byte[buffers[3].remaining()]; - } - - if (buffers[0].remaining() != 0) { - buffers[0].get(write1Bytes); - buffers[1].get(write1Bytes2); - buffers[2].get(write2Bytes1); - return write1Bytes.length + write1Bytes2.length + write2Bytes1.length; - } else { - buffers[1].get(write2Bytes2); - return write2Bytes2.length; - } - } - } - private byte[] generateBytes(int n) { n += 10; byte[] bytes = new byte[n]; @@ -292,5 +194,4 @@ public class TcpWriteContextTests extends ESTestCase { } return bytes; } - }