diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java deleted file mode 100644 index 14e2365eb7e..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.NetworkChannel; -import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; -import java.util.concurrent.CompletableFuture; -import java.util.function.BiConsumer; - -/** - * This is a basic channel abstraction used by the {@link ESSelector}. - *

- * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return - * true until the channel is explicitly closed. - *

- * A channel lifecycle has two stages: - *

    - *
  1. OPEN - When a channel has been created. This is the state of a channel that can perform normal operations. - *
  2. CLOSED - The channel has been set to closed. All this means is that the channel has been scheduled to be - * closed. The underlying raw channel may not yet be closed. The underlying channel has been closed if the close - * future has been completed. - *
- * - * @param the type of raw channel this AbstractNioChannel uses - */ -public abstract class AbstractNioChannel implements NioChannel { - - final S socketChannel; - - private final InetSocketAddress localAddress; - private final CompletableFuture closeContext = new CompletableFuture<>(); - private final ESSelector selector; - private SelectionKey selectionKey; - - AbstractNioChannel(S socketChannel, ESSelector selector) throws IOException { - this.socketChannel = socketChannel; - this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); - this.selector = selector; - } - - @Override - public boolean isOpen() { - return closeContext.isDone() == false; - } - - @Override - public InetSocketAddress getLocalAddress() { - return localAddress; - } - - /** - * Closes the channel synchronously. This method should only be called from the selector thread. - *

- * Once this method returns, the channel will be closed. - */ - @Override - public void closeFromSelector() throws IOException { - selector.assertOnSelectorThread(); - if (closeContext.isDone() == false) { - try { - socketChannel.close(); - closeContext.complete(null); - } catch (IOException e) { - closeContext.completeExceptionally(e); - throw e; - } - } - } - - /** - * This method attempts to registered a channel with the raw nio selector. It also sets the selection - * key. - * - * @throws ClosedChannelException if the raw channel was closed - */ - @Override - public void register() throws ClosedChannelException { - setSelectionKey(socketChannel.register(selector.rawSelector(), 0)); - } - - @Override - public ESSelector getSelector() { - return selector; - } - - @Override - public SelectionKey getSelectionKey() { - return selectionKey; - } - - @Override - public S getRawChannel() { - return socketChannel; - } - - @Override - public void addCloseListener(BiConsumer listener) { - closeContext.whenComplete(listener); - } - - @Override - public void close() { - getContext().closeChannel(); - } - - // Package visibility for testing - void setSelectionKey(SelectionKey selectionKey) { - this.selectionKey = selectionKey; - } -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java index 2cbf7657e5d..da64020daa8 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptingSelector.java @@ -24,6 +24,7 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.stream.Collectors; /** * Selector implementation that handles {@link NioServerSocketChannel}. It's main piece of functionality is @@ -46,12 +47,12 @@ public class AcceptingSelector extends ESSelector { @Override void processKey(SelectionKey selectionKey) { - NioServerSocketChannel serverChannel = (NioServerSocketChannel) selectionKey.attachment(); + ServerChannelContext channelContext = (ServerChannelContext) selectionKey.attachment(); if (selectionKey.isAcceptable()) { try { - eventHandler.acceptChannel(serverChannel); + eventHandler.acceptChannel(channelContext); } catch (IOException e) { - eventHandler.acceptException(serverChannel, e); + eventHandler.acceptException(channelContext, e); } } } @@ -63,7 +64,7 @@ public class AcceptingSelector extends ESSelector { @Override void cleanup() { - channelsToClose.addAll(newChannels); + channelsToClose.addAll(newChannels.stream().map(NioServerSocketChannel::getContext).collect(Collectors.toList())); } /** @@ -81,18 +82,16 @@ public class AcceptingSelector extends ESSelector { private void setUpNewServerChannels() { NioServerSocketChannel newChannel; while ((newChannel = this.newChannels.poll()) != null) { - assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; + ServerChannelContext context = newChannel.getContext(); + assert context.getSelector() == this : "The channel must be registered with the selector with which it was created"; try { - if (newChannel.isOpen()) { - newChannel.register(); - SelectionKey selectionKey = newChannel.getSelectionKey(); - selectionKey.attach(newChannel); - eventHandler.serverChannelRegistered(newChannel); + if (context.isOpen()) { + eventHandler.handleRegistration(context); } else { - eventHandler.registrationException(newChannel, new ClosedChannelException()); + eventHandler.registrationException(context, new ClosedChannelException()); } - } catch (IOException e) { - eventHandler.registrationException(newChannel, e); + } catch (Exception e) { + eventHandler.registrationException(context, e); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java index eb5194f21ef..474efad3c77 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AcceptorEventHandler.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import java.io.IOException; +import java.nio.channels.SelectionKey; import java.util.function.Supplier; /** @@ -38,46 +39,46 @@ public class AcceptorEventHandler extends EventHandler { } /** - * This method is called when a NioServerSocketChannel is successfully registered. It should only be - * called once per channel. + * This method is called when a NioServerSocketChannel is being registered with the selector. It should + * only be called once per channel. * - * @param nioServerSocketChannel that was registered + * @param context that was registered */ - protected void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) { - SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel); + protected void handleRegistration(ServerChannelContext context) throws IOException { + context.register(); + SelectionKey selectionKey = context.getSelectionKey(); + selectionKey.attach(context); + SelectionKeyUtils.setAcceptInterested(selectionKey); } /** * This method is called when an attempt to register a server channel throws an exception. * - * @param channel that was registered + * @param context that was registered * @param exception that occurred */ - protected void registrationException(NioServerSocketChannel channel, Exception exception) { - logger.error(new ParameterizedMessage("failed to register server channel: {}", channel), exception); + protected void registrationException(ServerChannelContext context, Exception exception) { + logger.error(new ParameterizedMessage("failed to register server channel: {}", context.getChannel()), exception); } /** * This method is called when a server channel signals it is ready to accept a connection. All of the * accept logic should occur in this call. * - * @param nioServerChannel that can accept a connection + * @param context that can accept a connection */ - protected void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException { - ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); - SocketSelector selector = selectorSupplier.get(); - NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector); - nioServerChannel.getContext().acceptChannel(nioSocketChannel); + protected void acceptChannel(ServerChannelContext context) throws IOException { + context.acceptChannels(selectorSupplier); } /** * This method is called when an attempt to accept a connection throws an exception. * - * @param nioServerChannel that accepting a connection + * @param context that accepting a connection * @param exception that occurred */ - protected void acceptException(NioServerSocketChannel nioServerChannel, Exception exception) { + protected void acceptException(ServerChannelContext context, Exception exception) { logger.debug(() -> new ParameterizedMessage("exception while accepting new channel from server channel: {}", - nioServerChannel), exception); + context.getChannel()), exception); } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java index 5d77675aa48..000e871e927 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java @@ -25,6 +25,7 @@ import java.nio.channels.ClosedChannelException; import java.util.LinkedList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; +import java.util.function.Consumer; public class BytesChannelContext extends SocketChannelContext { @@ -33,9 +34,9 @@ public class BytesChannelContext extends SocketChannelContext { private final LinkedList queued = new LinkedList<>(); private final AtomicBoolean isClosing = new AtomicBoolean(false); - public BytesChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler, + public BytesChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) { - super(channel, exceptionHandler); + super(channel, selector, exceptionHandler); this.readConsumer = readConsumer; this.channelBuffer = channelBuffer; } @@ -71,8 +72,8 @@ public class BytesChannelContext extends SocketChannelContext { return; } - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); - SocketSelector selector = channel.getSelector(); + BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener); + SocketSelector selector = getSelector(); if (selector.isOnCurrentThread() == false) { selector.queueWrite(writeOperation); return; @@ -83,13 +84,13 @@ public class BytesChannelContext extends SocketChannelContext { @Override public void queueWriteOperation(WriteOperation writeOperation) { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); queued.add((BytesWriteOperation) writeOperation); } @Override public void flushChannel() throws IOException { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); int ops = queued.size(); if (ops == 1) { singleFlush(queued.pop()); @@ -100,14 +101,14 @@ public class BytesChannelContext extends SocketChannelContext { @Override public boolean hasQueuedWriteOps() { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); return queued.isEmpty() == false; } @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - channel.getSelector().queueChannelClose(channel); + getSelector().queueChannelClose(channel); } } @@ -118,11 +119,11 @@ public class BytesChannelContext extends SocketChannelContext { @Override public void closeFromSelector() throws IOException { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); if (channel.isOpen()) { IOException channelCloseException = null; try { - channel.closeFromSelector(); + super.closeFromSelector(); } catch (IOException e) { channelCloseException = e; } @@ -130,7 +131,7 @@ public class BytesChannelContext extends SocketChannelContext { isClosing.set(true); channelBuffer.close(); for (BytesWriteOperation op : queued) { - channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); } queued.clear(); if (channelCloseException != null) { @@ -144,12 +145,12 @@ public class BytesChannelContext extends SocketChannelContext { int written = flushToChannel(headOp.getBuffersToWrite()); headOp.incrementIndex(written); } catch (IOException e) { - channel.getSelector().executeFailedListener(headOp.getListener(), e); + getSelector().executeFailedListener(headOp.getListener(), e); throw e; } if (headOp.isFullyFlushed()) { - channel.getSelector().executeListener(headOp.getListener(), null); + getSelector().executeListener(headOp.getListener(), null); } else { queued.push(headOp); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java index 14e8cace66d..37c6e497276 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java @@ -25,15 +25,15 @@ import java.util.function.BiConsumer; public class BytesWriteOperation implements WriteOperation { - private final NioSocketChannel channel; + private final SocketChannelContext channelContext; private final BiConsumer listener; private final ByteBuffer[] buffers; private final int[] offsets; private final int length; private int internalIndex; - public BytesWriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer listener) { - this.channel = channel; + public BytesWriteOperation(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { + this.channelContext = channelContext; this.listener = listener; this.buffers = buffers; this.offsets = new int[buffers.length]; @@ -52,8 +52,8 @@ public class BytesWriteOperation implements WriteOperation { } @Override - public NioSocketChannel getChannel() { - return channel; + public SocketChannelContext getChannel() { + return channelContext; } public boolean isFullyFlushed() { diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java index fa664484c1c..01f35347aa4 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java @@ -20,15 +20,78 @@ package org.elasticsearch.nio; import java.io.IOException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** + * Implements the logic related to interacting with a java.nio channel. For example: registering with a + * selector, managing the selection key, closing, etc is implemented by this class or its subclasses. + * + * @param the type of channel + */ +public abstract class ChannelContext { + + protected final S rawChannel; + private final Consumer exceptionHandler; + private final CompletableFuture closeContext = new CompletableFuture<>(); + private volatile SelectionKey selectionKey; + + ChannelContext(S rawChannel, Consumer exceptionHandler) { + this.rawChannel = rawChannel; + this.exceptionHandler = exceptionHandler; + } + + protected void register() throws IOException { + setSelectionKey(rawChannel.register(getSelector().rawSelector(), 0)); + } + + SelectionKey getSelectionKey() { + return selectionKey; + } + + // Protected for tests + protected void setSelectionKey(SelectionKey selectionKey) { + this.selectionKey = selectionKey; + } -public interface ChannelContext { /** * This method cleans up any context resources that need to be released when a channel is closed. It * should only be called by the selector thread. * * @throws IOException during channel / context close */ - void closeFromSelector() throws IOException; + public void closeFromSelector() throws IOException { + if (closeContext.isDone() == false) { + try { + rawChannel.close(); + closeContext.complete(null); + } catch (Exception e) { + closeContext.completeExceptionally(e); + throw e; + } + } + } + + /** + * Add a listener that will be called when the channel is closed. + * + * @param listener to be called + */ + public void addCloseListener(BiConsumer listener) { + closeContext.whenComplete(listener); + } + + public boolean isOpen() { + return closeContext.isDone() == false; + } + + void handleException(Exception e) { + exceptionHandler.accept(e); + } /** * Schedules a channel to be closed by the selector event loop with which it is registered. @@ -39,7 +102,10 @@ public interface ChannelContext { * Depending on the underlying protocol of the channel, a close operation might simply close the socket * channel or may involve reading and writing messages. */ - void closeChannel(); + public abstract void closeChannel(); + + public abstract ESSelector getSelector(); + + public abstract NioChannel getChannel(); - void handleException(Exception e); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java index 5fc3f46f998..a03c4bcc15b 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java @@ -27,6 +27,7 @@ import java.nio.channels.SocketChannel; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.function.Supplier; public abstract class ChannelFactory { @@ -41,22 +42,30 @@ public abstract class ChannelFactory supplier) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); + SocketSelector selector = supplier.get(); Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public Socket acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { - SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); - Socket channel = internalCreateChannel(selector, rawChannel); - scheduleChannel(channel, selector); - return channel; + public Socket acceptNioChannel(ServerChannelContext serverContext, Supplier supplier) throws IOException { + SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverContext); + // Null is returned if there are no pending sockets to accept + if (rawChannel == null) { + return null; + } else { + SocketSelector selector = supplier.get(); + Socket channel = internalCreateChannel(selector, rawChannel); + scheduleChannel(channel, selector); + return channel; + } } - public ServerSocket openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) throws IOException { + public ServerSocket openNioServerSocketChannel(InetSocketAddress address, Supplier supplier) throws IOException { ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); + AcceptingSelector selector = supplier.get(); ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); scheduleServerChannel(serverChannel, selector); return serverChannel; @@ -140,7 +149,7 @@ public abstract class ChannelFactory channelsToClose = new ConcurrentLinkedQueue<>(); + final ConcurrentLinkedQueue> channelsToClose = new ConcurrentLinkedQueue<>(); private final EventHandler eventHandler; private final ReentrantLock runLock = new ReentrantLock(); @@ -60,7 +60,7 @@ public abstract class ESSelector implements Closeable { this(eventHandler, Selector.open()); } - ESSelector(EventHandler eventHandler, Selector selector) throws IOException { + ESSelector(EventHandler eventHandler, Selector selector) { this.eventHandler = eventHandler; this.selector = selector; } @@ -111,10 +111,10 @@ public abstract class ESSelector implements Closeable { try { processKey(sk); } catch (CancelledKeyException cke) { - eventHandler.genericChannelException((NioChannel) sk.attachment(), cke); + eventHandler.genericChannelException((ChannelContext) sk.attachment(), cke); } } else { - eventHandler.genericChannelException((NioChannel) sk.attachment(), new CancelledKeyException()); + eventHandler.genericChannelException((ChannelContext) sk.attachment(), new CancelledKeyException()); } } } @@ -131,7 +131,7 @@ public abstract class ESSelector implements Closeable { void cleanupAndCloseChannels() { cleanup(); - channelsToClose.addAll(selector.keys().stream().map(sk -> (NioChannel) sk.attachment()).collect(Collectors.toList())); + channelsToClose.addAll(selector.keys().stream().map(sk -> (ChannelContext) sk.attachment()).collect(Collectors.toList())); closePendingChannels(); } @@ -191,9 +191,10 @@ public abstract class ESSelector implements Closeable { } public void queueChannelClose(NioChannel channel) { - assert channel.getSelector() == this : "Must schedule a channel for closure with its selector"; - channelsToClose.offer(channel); - ensureSelectorOpenForEnqueuing(channelsToClose, channel); + ChannelContext context = channel.getContext(); + assert context.getSelector() == this : "Must schedule a channel for closure with its selector"; + channelsToClose.offer(context); + ensureSelectorOpenForEnqueuing(channelsToClose, context); wakeup(); } @@ -239,9 +240,9 @@ public abstract class ESSelector implements Closeable { } private void closePendingChannels() { - NioChannel channel; - while ((channel = channelsToClose.poll()) != null) { - eventHandler.handleClose(channel); + ChannelContext channelContext; + while ((channelContext = channelsToClose.poll()) != null) { + eventHandler.handleClose(channelContext); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java index 7cba9b998b3..d35b73c56b8 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/EventHandler.java @@ -65,25 +65,25 @@ public abstract class EventHandler { /** * This method handles the closing of an NioChannel * - * @param channel that should be closed + * @param context that should be closed */ - protected void handleClose(NioChannel channel) { + protected void handleClose(ChannelContext context) { try { - channel.getContext().closeFromSelector(); + context.closeFromSelector(); } catch (IOException e) { - closeException(channel, e); + closeException(context, e); } - assert channel.isOpen() == false : "Should always be done as we are on the selector thread"; + assert context.isOpen() == false : "Should always be done as we are on the selector thread"; } /** * This method is called when an attempt to close a channel throws an exception. * - * @param channel that was being closed + * @param context that was being closed * @param exception that occurred */ - protected void closeException(NioChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while closing channel: {}", channel), exception); + protected void closeException(ChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while closing channel: {}", context.getChannel()), exception); } /** @@ -94,7 +94,7 @@ public abstract class EventHandler { * @param channel that caused the exception * @param exception that was thrown */ - protected void genericChannelException(NioChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while handling event for channel: {}", channel), exception); + protected void genericChannelException(ChannelContext channel, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while handling event for channel: {}", channel.getChannel()), exception); } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java index 690e3d3b38b..2f9705f5f8f 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java @@ -21,30 +21,30 @@ package org.elasticsearch.nio; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; -import java.nio.channels.SelectionKey; import java.util.function.BiConsumer; -public interface NioChannel { +/** + * This is a basic channel abstraction used by the {@link ESSelector}. + *

+ * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return + * true until the channel is explicitly closed. + */ +public abstract class NioChannel { - boolean isOpen(); + private final InetSocketAddress localAddress; - InetSocketAddress getLocalAddress(); + NioChannel(NetworkChannel socketChannel) throws IOException { + this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); + } - void close(); + public boolean isOpen() { + return getContext().isOpen(); + } - void closeFromSelector() throws IOException; - - void register() throws ClosedChannelException; - - ESSelector getSelector(); - - SelectionKey getSelectionKey(); - - NetworkChannel getRawChannel(); - - ChannelContext getContext(); + public InetSocketAddress getLocalAddress() { + return localAddress; + } /** * Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee @@ -53,5 +53,18 @@ public interface NioChannel { * * @param listener to be called at close */ - void addCloseListener(BiConsumer listener); + public void addCloseListener(BiConsumer listener) { + getContext().addCloseListener(listener); + } + + /** + * Schedules channel for close. This process is asynchronous. + */ + public void close() { + getContext().closeChannel(); + } + + public abstract NetworkChannel getRawChannel(); + + public abstract ChannelContext getContext(); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java index 109d22c45fd..b7637656162 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioGroup.java @@ -96,12 +96,12 @@ public class NioGroup implements AutoCloseable { if (acceptors.isEmpty()) { throw new IllegalArgumentException("There are no acceptors configured. Without acceptors, server channels are not supported."); } - return factory.openNioServerSocketChannel(address, acceptorSupplier.get()); + return factory.openNioServerSocketChannel(address, acceptorSupplier); } public S openChannel(InetSocketAddress address, ChannelFactory factory) throws IOException { ensureOpen(); - return factory.openNioChannel(address, socketSelectorSupplier.get()); + return factory.openNioChannel(address, socketSelectorSupplier); } @Override diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java index 3d1748e413a..9f78c3b1b31 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java @@ -23,20 +23,15 @@ import java.io.IOException; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.atomic.AtomicBoolean; -public class NioServerSocketChannel extends AbstractNioChannel { +public class NioServerSocketChannel extends NioChannel { - private final ChannelFactory channelFactory; - private ServerChannelContext context; + private final ServerSocketChannel socketChannel; private final AtomicBoolean contextSet = new AtomicBoolean(false); + private ServerChannelContext context; - public NioServerSocketChannel(ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector) - throws IOException { - super(socketChannel, selector); - this.channelFactory = channelFactory; - } - - public ChannelFactory getChannelFactory() { - return channelFactory; + public NioServerSocketChannel(ServerSocketChannel socketChannel) throws IOException { + super(socketChannel); + this.socketChannel = socketChannel; } /** @@ -53,6 +48,11 @@ public class NioServerSocketChannel extends AbstractNioChannel { +public class NioSocketChannel extends NioChannel { private final InetSocketAddress remoteAddress; - private final CompletableFuture connectContext = new CompletableFuture<>(); - private final SocketSelector socketSelector; private final AtomicBoolean contextSet = new AtomicBoolean(false); + private final SocketChannel socketChannel; private SocketChannelContext context; - private Exception connectException; - public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + public NioSocketChannel(SocketChannel socketChannel) throws IOException { + super(socketChannel); + this.socketChannel = socketChannel; this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress(); - this.socketSelector = selector; - } - - @Override - public SocketSelector getSelector() { - return socketSelector; - } - - public int write(ByteBuffer buffer) throws IOException { - return socketChannel.write(buffer); - } - - public int write(ByteBuffer[] buffers) throws IOException { - if (buffers.length == 1) { - return socketChannel.write(buffers[0]); - } else { - return (int) socketChannel.write(buffers); - } - } - - public int read(ByteBuffer buffer) throws IOException { - return socketChannel.read(buffer); - } - - public int read(ByteBuffer[] buffers) throws IOException { - if (buffers.length == 1) { - return socketChannel.read(buffers[0]); - } else { - return (int) socketChannel.read(buffers); - } } public void setContext(SocketChannelContext context) { @@ -79,6 +46,11 @@ public class NioSocketChannel extends AbstractNioChannel { } } + @Override + public SocketChannel getRawChannel() { + return socketChannel; + } + @Override public SocketChannelContext getContext() { return context; @@ -88,46 +60,8 @@ public class NioSocketChannel extends AbstractNioChannel { return remoteAddress; } - public boolean isConnectComplete() { - return isConnectComplete0(); - } - - /** - * This method will attempt to complete the connection process for this channel. It should be called for - * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then - * the connection is complete and the channel is ready for reads and writes. If it returns false, the - * channel is not yet connected and this method should be called again when a OP_CONNECT event is - * received. - * - * @return true if the connection process is complete - * @throws IOException if an I/O error occurs - */ - public boolean finishConnect() throws IOException { - if (isConnectComplete0()) { - return true; - } else if (connectContext.isCompletedExceptionally()) { - Exception exception = connectException; - if (exception == null) { - throw new AssertionError("Should have received connection exception"); - } else if (exception instanceof IOException) { - throw (IOException) exception; - } else { - throw (RuntimeException) exception; - } - } - - boolean isConnected = socketChannel.isConnected(); - if (isConnected == false) { - isConnected = internalFinish(); - } - if (isConnected) { - connectContext.complete(null); - } - return isConnected; - } - public void addConnectListener(BiConsumer listener) { - connectContext.whenComplete(listener); + context.addConnectListener(listener); } @Override @@ -137,18 +71,4 @@ public class NioSocketChannel extends AbstractNioChannel { ", remoteAddress=" + remoteAddress + '}'; } - - private boolean internalFinish() throws IOException { - try { - return socketChannel.finishConnect(); - } catch (IOException | RuntimeException e) { - connectException = e; - connectContext.completeExceptionally(e); - throw e; - } - } - - private boolean isConnectComplete0() { - return connectContext.isDone() && connectContext.isCompletedExceptionally() == false; - } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java index be2dc6f3414..93d58344de6 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java @@ -27,80 +27,75 @@ public final class SelectionKeyUtils { private SelectionKeyUtils() {} /** - * Adds an interest in writes for this channel while maintaining other interests. + * Adds an interest in writes for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE); } /** - * Removes an interest in writes for this channel while maintaining other interests. + * Removes an interest in writes for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void removeWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE); } /** - * Removes an interest in connects and reads for this channel while maintaining other interests. + * Removes an interest in connects and reads for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setConnectAndReadInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ); } /** - * Removes an interest in connects, reads, and writes for this channel while maintaining other interests. + * Removes an interest in connects, reads, and writes for this selection key while maintaining other + * interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setConnectReadAndWriteInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setConnectReadAndWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ | SelectionKey.OP_WRITE); } /** - * Removes an interest in connects for this channel while maintaining other interests. + * Removes an interest in connects for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void removeConnectInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT); } /** - * Adds an interest in accepts for this channel while maintaining other interests. + * Adds an interest in accepts for this selection key while maintaining other interests. * - * @param channel the channel + * @param selectionKey the selection key * @throws CancelledKeyException if the key was already cancelled */ - public static void setAcceptInterested(NioServerSocketChannel channel) throws CancelledKeyException { - SelectionKey selectionKey = channel.getSelectionKey(); + public static void setAcceptInterested(SelectionKey selectionKey) throws CancelledKeyException { selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT); } /** - * Checks for an interest in writes for this channel. + * Checks for an interest in writes for this selection key. * - * @param channel the channel + * @param selectionKey the selection key * @return a boolean indicating if we are currently interested in writes for this channel * @throws CancelledKeyException if the key was already cancelled */ - public static boolean isWriteInterested(NioSocketChannel channel) throws CancelledKeyException { - return (channel.getSelectionKey().interestOps() & SelectionKey.OP_WRITE) != 0; + public static boolean isWriteInterested(SelectionKey selectionKey) throws CancelledKeyException { + return (selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0; } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java index 551cab48e05..4b47ce063f9 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java @@ -20,43 +20,50 @@ package org.elasticsearch.nio; import java.io.IOException; +import java.nio.channels.ServerSocketChannel; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; -public class ServerChannelContext implements ChannelContext { +public class ServerChannelContext extends ChannelContext { private final NioServerSocketChannel channel; + private final AcceptingSelector selector; private final Consumer acceptor; - private final BiConsumer exceptionHandler; private final AtomicBoolean isClosing = new AtomicBoolean(false); + private final ChannelFactory channelFactory; - public ServerChannelContext(NioServerSocketChannel channel, Consumer acceptor, - BiConsumer exceptionHandler) { + public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector, + Consumer acceptor, Consumer exceptionHandler) { + super(channel.getRawChannel(), exceptionHandler); this.channel = channel; + this.channelFactory = channelFactory; + this.selector = selector; this.acceptor = acceptor; - this.exceptionHandler = exceptionHandler; } - public void acceptChannel(NioSocketChannel acceptedChannel) { - acceptor.accept(acceptedChannel); - } - - @Override - public void closeFromSelector() throws IOException { - channel.closeFromSelector(); + public void acceptChannels(Supplier selectorSupplier) throws IOException { + NioSocketChannel acceptedChannel; + while ((acceptedChannel = channelFactory.acceptNioChannel(this, selectorSupplier)) != null) { + acceptor.accept(acceptedChannel); + } } @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - channel.getSelector().queueChannelClose(channel); + getSelector().queueChannelClose(channel); } } @Override - public void handleException(Exception e) { - exceptionHandler.accept(channel, e); + public AcceptingSelector getSelector() { + return selector; } + + @Override + public NioServerSocketChannel getChannel() { + return channel; + } + } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 62f82e8995d..3bf47a98e02 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -21,7 +21,10 @@ package org.elasticsearch.nio; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.function.Consumer; /** * This context should implement the specific logic for a channel. When a channel receives a notification @@ -32,24 +35,78 @@ import java.util.function.BiConsumer; * The only methods of the context that should ever be called from a non-selector thread are * {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}. */ -public abstract class SocketChannelContext implements ChannelContext { +public abstract class SocketChannelContext extends ChannelContext { protected final NioSocketChannel channel; - private final BiConsumer exceptionHandler; + private final SocketSelector selector; + private final CompletableFuture connectContext = new CompletableFuture<>(); private boolean ioException; private boolean peerClosed; + private Exception connectException; - protected SocketChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler) { + protected SocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { + super(channel.getRawChannel(), exceptionHandler); + this.selector = selector; this.channel = channel; - this.exceptionHandler = exceptionHandler; } @Override - public void handleException(Exception e) { - exceptionHandler.accept(channel, e); + public SocketSelector getSelector() { + return selector; } - public void channelRegistered() throws IOException {} + @Override + public NioSocketChannel getChannel() { + return channel; + } + + public void addConnectListener(BiConsumer listener) { + connectContext.whenComplete(listener); + } + + public boolean isConnectComplete() { + return connectContext.isDone() && connectContext.isCompletedExceptionally() == false; + } + + /** + * This method will attempt to complete the connection process for this channel. It should be called for + * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then + * the connection is complete and the channel is ready for reads and writes. If it returns false, the + * channel is not yet connected and this method should be called again when a OP_CONNECT event is + * received. + * + * @return true if the connection process is complete + * @throws IOException if an I/O error occurs + */ + public boolean connect() throws IOException { + if (isConnectComplete()) { + return true; + } else if (connectContext.isCompletedExceptionally()) { + Exception exception = connectException; + if (exception == null) { + throw new AssertionError("Should have received connection exception"); + } else if (exception instanceof IOException) { + throw (IOException) exception; + } else { + throw (RuntimeException) exception; + } + } + + boolean isConnected = rawChannel.isConnected(); + if (isConnected == false) { + try { + isConnected = rawChannel.finishConnect(); + } catch (IOException | RuntimeException e) { + connectException = e; + connectContext.completeExceptionally(e); + throw e; + } + } + if (isConnected) { + connectContext.complete(null); + } + return isConnected; + } public abstract int read() throws IOException; @@ -78,7 +135,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int readFromChannel(ByteBuffer buffer) throws IOException { try { - int bytesRead = channel.read(buffer); + int bytesRead = rawChannel.read(buffer); if (bytesRead < 0) { peerClosed = true; bytesRead = 0; @@ -92,7 +149,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int readFromChannel(ByteBuffer[] buffers) throws IOException { try { - int bytesRead = channel.read(buffers); + int bytesRead = (int) rawChannel.read(buffers); if (bytesRead < 0) { peerClosed = true; bytesRead = 0; @@ -106,7 +163,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int flushToChannel(ByteBuffer buffer) throws IOException { try { - return channel.write(buffer); + return rawChannel.write(buffer); } catch (IOException e) { ioException = true; throw e; @@ -115,7 +172,7 @@ public abstract class SocketChannelContext implements ChannelContext { protected int flushToChannel(ByteBuffer[] buffers) throws IOException { try { - return channel.write(buffers); + return (int) rawChannel.write(buffers); } catch (IOException e) { ioException = true; throw e; diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java index b1192f11eb1..b1f73864761 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import java.io.IOException; +import java.nio.channels.SelectionKey; import java.util.function.BiConsumer; /** @@ -41,91 +42,93 @@ public class SocketEventHandler extends EventHandler { * This method is called when a NioSocketChannel is successfully registered. It should only be called * once per channel. * - * @param channel that was registered + * @param context that was registered */ - protected void handleRegistration(NioSocketChannel channel) throws IOException { - SocketChannelContext context = channel.getContext(); - context.channelRegistered(); + protected void handleRegistration(SocketChannelContext context) throws IOException { + context.register(); + SelectionKey selectionKey = context.getSelectionKey(); + selectionKey.attach(context); if (context.hasQueuedWriteOps()) { - SelectionKeyUtils.setConnectReadAndWriteInterested(channel); + SelectionKeyUtils.setConnectReadAndWriteInterested(selectionKey); } else { - SelectionKeyUtils.setConnectAndReadInterested(channel); + SelectionKeyUtils.setConnectAndReadInterested(selectionKey); } } /** * This method is called when an attempt to register a channel throws an exception. * - * @param channel that was registered + * @param context that was registered * @param exception that occurred */ - protected void registrationException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("failed to register socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void registrationException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("failed to register socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** - * This method is called when a NioSocketChannel is successfully connected. It should only be called - * once per channel. + * This method is called when a NioSocketChannel has just been accepted or if it has receive an + * OP_CONNECT event. * - * @param channel that was registered + * @param context that was registered */ - protected void handleConnect(NioSocketChannel channel) { - SelectionKeyUtils.removeConnectInterested(channel); + protected void handleConnect(SocketChannelContext context) throws IOException { + if (context.connect()) { + SelectionKeyUtils.removeConnectInterested(context.getSelectionKey()); + } } /** * This method is called when an attempt to connect a channel throws an exception. * - * @param channel that was connecting + * @param context that was connecting * @param exception that occurred */ - protected void connectException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("failed to connect to socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void connectException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("failed to connect to socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** * This method is called when a channel signals it is ready for be read. All of the read logic should * occur in this call. * - * @param channel that can be read + * @param context that can be read */ - protected void handleRead(NioSocketChannel channel) throws IOException { - channel.getContext().read(); + protected void handleRead(SocketChannelContext context) throws IOException { + context.read(); } /** * This method is called when an attempt to read from a channel throws an exception. * - * @param channel that was being read + * @param context that was being read * @param exception that occurred */ - protected void readException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while reading from socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void readException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while reading from socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** * This method is called when a channel signals it is ready to receive writes. All of the write logic * should occur in this call. * - * @param channel that can be written to + * @param context that can be written to */ - protected void handleWrite(NioSocketChannel channel) throws IOException { - SocketChannelContext channelContext = channel.getContext(); - channelContext.flushChannel(); + protected void handleWrite(SocketChannelContext context) throws IOException { + context.flushChannel(); } /** * This method is called when an attempt to write to a channel throws an exception. * - * @param channel that was being written to + * @param context that was being written to * @param exception that occurred */ - protected void writeException(NioSocketChannel channel, Exception exception) { - logger.debug(() -> new ParameterizedMessage("exception while writing to socket channel: {}", channel), exception); - channel.getContext().handleException(exception); + protected void writeException(SocketChannelContext context, Exception exception) { + logger.debug(() -> new ParameterizedMessage("exception while writing to socket channel: {}", context.getChannel()), exception); + context.handleException(exception); } /** @@ -139,18 +142,19 @@ public class SocketEventHandler extends EventHandler { } /** - * @param channel that was handled + * @param context that was handled */ - protected void postHandling(NioSocketChannel channel) { - if (channel.getContext().selectorShouldClose()) { - handleClose(channel); + protected void postHandling(SocketChannelContext context) { + if (context.selectorShouldClose()) { + handleClose(context); } else { - boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(channel); - boolean pendingWrites = channel.getContext().hasQueuedWriteOps(); + SelectionKey selectionKey = context.getSelectionKey(); + boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey); + boolean pendingWrites = context.hasQueuedWriteOps(); if (currentlyWriteInterested == false && pendingWrites) { - SelectionKeyUtils.setWriteInterested(channel); + SelectionKeyUtils.setWriteInterested(selectionKey); } else if (currentlyWriteInterested && pendingWrites == false) { - SelectionKeyUtils.removeWriteInterested(channel); + SelectionKeyUtils.removeWriteInterested(selectionKey); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java index acfec6ca04e..b1a3a08f02d 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java @@ -33,7 +33,7 @@ import java.util.function.BiConsumer; */ public class SocketSelector extends ESSelector { - private final ConcurrentLinkedQueue newChannels = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue newChannels = new ConcurrentLinkedQueue<>(); private final ConcurrentLinkedQueue queuedWrites = new ConcurrentLinkedQueue<>(); private final SocketEventHandler eventHandler; @@ -49,23 +49,23 @@ public class SocketSelector extends ESSelector { @Override void processKey(SelectionKey selectionKey) { - NioSocketChannel nioSocketChannel = (NioSocketChannel) selectionKey.attachment(); + SocketChannelContext channelContext = (SocketChannelContext) selectionKey.attachment(); int ops = selectionKey.readyOps(); if ((ops & SelectionKey.OP_CONNECT) != 0) { - attemptConnect(nioSocketChannel, true); + attemptConnect(channelContext, true); } - if (nioSocketChannel.isConnectComplete()) { + if (channelContext.isConnectComplete()) { if ((ops & SelectionKey.OP_WRITE) != 0) { - handleWrite(nioSocketChannel); + handleWrite(channelContext); } if ((ops & SelectionKey.OP_READ) != 0) { - handleRead(nioSocketChannel); + handleRead(channelContext); } } - eventHandler.postHandling(nioSocketChannel); + eventHandler.postHandling(channelContext); } @Override @@ -89,8 +89,9 @@ public class SocketSelector extends ESSelector { * @param nioSocketChannel the channel to register */ public void scheduleForRegistration(NioSocketChannel nioSocketChannel) { - newChannels.offer(nioSocketChannel); - ensureSelectorOpenForEnqueuing(newChannels, nioSocketChannel); + SocketChannelContext channelContext = nioSocketChannel.getContext(); + newChannels.offer(channelContext); + ensureSelectorOpenForEnqueuing(newChannels, channelContext); wakeup(); } @@ -121,10 +122,9 @@ public class SocketSelector extends ESSelector { */ public void queueWriteInChannelBuffer(WriteOperation writeOperation) { assertOnSelectorThread(); - NioSocketChannel channel = writeOperation.getChannel(); - SocketChannelContext context = channel.getContext(); + SocketChannelContext context = writeOperation.getChannel(); try { - SelectionKeyUtils.setWriteInterested(channel); + SelectionKeyUtils.setWriteInterested(context.getSelectionKey()); context.queueWriteOperation(writeOperation); } catch (Exception e) { executeFailedListener(writeOperation.getListener(), e); @@ -163,19 +163,19 @@ public class SocketSelector extends ESSelector { } } - private void handleWrite(NioSocketChannel nioSocketChannel) { + private void handleWrite(SocketChannelContext context) { try { - eventHandler.handleWrite(nioSocketChannel); + eventHandler.handleWrite(context); } catch (Exception e) { - eventHandler.writeException(nioSocketChannel, e); + eventHandler.writeException(context, e); } } - private void handleRead(NioSocketChannel nioSocketChannel) { + private void handleRead(SocketChannelContext context) { try { - eventHandler.handleRead(nioSocketChannel); + eventHandler.handleRead(context); } catch (Exception e) { - eventHandler.readException(nioSocketChannel, e); + eventHandler.readException(context, e); } } @@ -191,38 +191,34 @@ public class SocketSelector extends ESSelector { } private void setUpNewChannels() { - NioSocketChannel newChannel; - while ((newChannel = this.newChannels.poll()) != null) { - setupChannel(newChannel); + SocketChannelContext channelContext; + while ((channelContext = this.newChannels.poll()) != null) { + setupChannel(channelContext); } } - private void setupChannel(NioSocketChannel newChannel) { - assert newChannel.getSelector() == this : "The channel must be registered with the selector with which it was created"; + private void setupChannel(SocketChannelContext context) { + assert context.getSelector() == this : "The channel must be registered with the selector with which it was created"; try { - if (newChannel.isOpen()) { - newChannel.register(); - SelectionKey key = newChannel.getSelectionKey(); - key.attach(newChannel); - eventHandler.handleRegistration(newChannel); - attemptConnect(newChannel, false); + if (context.isOpen()) { + eventHandler.handleRegistration(context); + attemptConnect(context, false); } else { - eventHandler.registrationException(newChannel, new ClosedChannelException()); + eventHandler.registrationException(context, new ClosedChannelException()); } } catch (Exception e) { - eventHandler.registrationException(newChannel, e); + eventHandler.registrationException(context, e); } } - private void attemptConnect(NioSocketChannel newChannel, boolean connectEvent) { + private void attemptConnect(SocketChannelContext context, boolean connectEvent) { try { - if (newChannel.finishConnect()) { - eventHandler.handleConnect(newChannel); - } else if (connectEvent) { - eventHandler.connectException(newChannel, new IOException("Received OP_CONNECT but connect failed")); + eventHandler.handleConnect(context); + if (connectEvent && context.isConnectComplete() == false) { + eventHandler.connectException(context, new IOException("Received OP_CONNECT but connect failed")); } } catch (Exception e) { - eventHandler.connectException(newChannel, e); + eventHandler.connectException(context, e); } } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java index d2dfe4f37a0..665b9f7759e 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java @@ -30,5 +30,5 @@ public interface WriteOperation { BiConsumer getListener(); - NioSocketChannel getChannel(); + SocketChannelContext getChannel(); } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java index 048aa3af8ff..7536ad9d1e1 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptingSelectorTests.java @@ -44,6 +44,7 @@ public class AcceptingSelectorTests extends ESTestCase { private AcceptorEventHandler eventHandler; private TestSelectionKey selectionKey; private Selector rawSelector; + private ServerChannelContext context; @Before public void setUp() throws Exception { @@ -56,39 +57,41 @@ public class AcceptingSelectorTests extends ESTestCase { selector = new AcceptingSelector(eventHandler, rawSelector); this.selector.setThread(); + context = mock(ServerChannelContext.class); selectionKey = new TestSelectionKey(0); - selectionKey.attach(serverChannel); - when(serverChannel.getSelectionKey()).thenReturn(selectionKey); - when(serverChannel.getSelector()).thenReturn(selector); - when(serverChannel.isOpen()).thenReturn(true); + selectionKey.attach(context); + when(context.getSelectionKey()).thenReturn(selectionKey); + when(context.getSelector()).thenReturn(selector); + when(context.isOpen()).thenReturn(true); + when(serverChannel.getContext()).thenReturn(context); } - public void testRegisteredChannel() throws IOException, PrivilegedActionException { + public void testRegisteredChannel() throws IOException { selector.scheduleForRegistration(serverChannel); selector.preSelect(); - verify(eventHandler).serverChannelRegistered(serverChannel); + verify(eventHandler).handleRegistration(context); } - public void testClosedChannelWillNotBeRegistered() throws Exception { - when(serverChannel.isOpen()).thenReturn(false); + public void testClosedChannelWillNotBeRegistered() { + when(context.isOpen()).thenReturn(false); selector.scheduleForRegistration(serverChannel); selector.preSelect(); - verify(eventHandler).registrationException(same(serverChannel), any(ClosedChannelException.class)); + verify(eventHandler).registrationException(same(context), any(ClosedChannelException.class)); } public void testRegisterChannelFailsDueToException() throws Exception { selector.scheduleForRegistration(serverChannel); ClosedChannelException closedChannelException = new ClosedChannelException(); - doThrow(closedChannelException).when(serverChannel).register(); + doThrow(closedChannelException).when(eventHandler).handleRegistration(context); selector.preSelect(); - verify(eventHandler).registrationException(serverChannel, closedChannelException); + verify(eventHandler).registrationException(context, closedChannelException); } public void testAcceptEvent() throws IOException { @@ -96,18 +99,18 @@ public class AcceptingSelectorTests extends ESTestCase { selector.processKey(selectionKey); - verify(eventHandler).acceptChannel(serverChannel); + verify(eventHandler).acceptChannel(context); } public void testAcceptException() throws IOException { selectionKey.setReadyOps(SelectionKey.OP_ACCEPT); IOException ioException = new IOException(); - doThrow(ioException).when(eventHandler).acceptChannel(serverChannel); + doThrow(ioException).when(eventHandler).acceptChannel(context); selector.processKey(selectionKey); - verify(eventHandler).acceptException(serverChannel, ioException); + verify(eventHandler).acceptException(context, ioException); } public void testCleanup() throws IOException { @@ -116,11 +119,11 @@ public class AcceptingSelectorTests extends ESTestCase { selector.preSelect(); TestSelectionKey key = new TestSelectionKey(0); - key.attach(serverChannel); + key.attach(context); when(rawSelector.keys()).thenReturn(new HashSet<>(Collections.singletonList(key))); selector.cleanupAndCloseChannels(); - verify(eventHandler).handleClose(serverChannel); + verify(eventHandler).handleClose(context); } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java index 23ab3bb3e1d..50469b30acd 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java @@ -27,62 +27,89 @@ import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; +import java.util.function.Consumer; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AcceptorEventHandlerTests extends ESTestCase { private AcceptorEventHandler handler; - private SocketSelector socketSelector; private ChannelFactory channelFactory; private NioServerSocketChannel channel; - private ServerChannelContext context; + private DoNotRegisterContext context; + private RoundRobinSupplier selectorSupplier; @Before @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { channelFactory = mock(ChannelFactory.class); - socketSelector = mock(SocketSelector.class); - context = mock(ServerChannelContext.class); ArrayList selectors = new ArrayList<>(); - selectors.add(socketSelector); - handler = new AcceptorEventHandler(logger, new RoundRobinSupplier<>(selectors.toArray(new SocketSelector[selectors.size()]))); + selectors.add(mock(SocketSelector.class)); + selectorSupplier = new RoundRobinSupplier<>(selectors.toArray(new SocketSelector[selectors.size()])); + handler = new AcceptorEventHandler(logger, selectorSupplier); - AcceptingSelector selector = mock(AcceptingSelector.class); - channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector); + channel = new NioServerSocketChannel(mock(ServerSocketChannel.class)); + context = new DoNotRegisterContext(channel, mock(AcceptingSelector.class), mock(Consumer.class)); channel.setContext(context); - channel.register(); } - public void testHandleRegisterSetsOP_ACCEPTInterest() { - assertEquals(0, channel.getSelectionKey().interestOps()); + public void testHandleRegisterSetsOP_ACCEPTInterest() throws IOException { + assertNull(context.getSelectionKey()); - handler.serverChannelRegistered(channel); + handler.handleRegistration(context); - assertEquals(SelectionKey.OP_ACCEPT, channel.getSelectionKey().interestOps()); + assertEquals(SelectionKey.OP_ACCEPT, channel.getContext().getSelectionKey().interestOps()); + } + + public void testRegisterAddsAttachment() throws IOException { + assertNull(context.getSelectionKey()); + + handler.handleRegistration(context); + + assertEquals(context, context.getSelectionKey().attachment()); } public void testHandleAcceptCallsChannelFactory() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); + NioSocketChannel nullChannel = null; + when(channelFactory.acceptNioChannel(same(context), same(selectorSupplier))).thenReturn(childChannel, nullChannel); - handler.acceptChannel(channel); - - verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector)); + handler.acceptChannel(context); + verify(channelFactory, times(2)).acceptNioChannel(same(context), same(selectorSupplier)); } @SuppressWarnings("unchecked") public void testHandleAcceptCallsServerAcceptCallback() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); - childChannel.setContext(mock(SocketChannelContext.class)); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); + SocketChannelContext childContext = mock(SocketChannelContext.class); + childChannel.setContext(childContext); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + channel = new NioServerSocketChannel(mock(ServerSocketChannel.class)); + channel.setContext(serverChannelContext); + when(serverChannelContext.getChannel()).thenReturn(channel); + when(channelFactory.acceptNioChannel(same(context), same(selectorSupplier))).thenReturn(childChannel); - handler.acceptChannel(channel); + handler.acceptChannel(serverChannelContext); - verify(context).acceptChannel(childChannel); + verify(serverChannelContext).acceptChannels(selectorSupplier); + } + + private class DoNotRegisterContext extends ServerChannelContext { + + + @SuppressWarnings("unchecked") + DoNotRegisterContext(NioServerSocketChannel channel, AcceptingSelector selector, Consumer acceptor) { + super(channel, channelFactory, selector, acceptor, mock(Consumer.class)); + } + + @Override + public void register() { + setSelectionKey(new TestSelectionKey(0)); + } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index 68ae1f2e503..d9de0ab1361 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -27,10 +27,13 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; @@ -42,9 +45,11 @@ public class BytesChannelContextTests extends ESTestCase { private SocketChannelContext.ReadConsumer readConsumer; private NioSocketChannel channel; + private SocketChannel rawChannel; private BytesChannelContext context; private InboundChannelBuffer channelBuffer; private SocketSelector selector; + private Consumer exceptionHandler; private BiConsumer listener; private int messageLength; @@ -57,17 +62,19 @@ public class BytesChannelContextTests extends ESTestCase { selector = mock(SocketSelector.class); listener = mock(BiConsumer.class); channel = mock(NioSocketChannel.class); + rawChannel = mock(SocketChannel.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new BytesChannelContext(channel, null, readConsumer, channelBuffer); + exceptionHandler = mock(Consumer.class); + when(channel.getRawChannel()).thenReturn(rawChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - when(channel.getSelector()).thenReturn(selector); when(selector.isOnCurrentThread()).thenReturn(true); } public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -85,7 +92,7 @@ public class BytesChannelContextTests extends ESTestCase { public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -103,7 +110,7 @@ public class BytesChannelContextTests extends ESTestCase { public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> { ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; buffers[0].put(bytes); return bytes.length; @@ -128,14 +135,14 @@ public class BytesChannelContextTests extends ESTestCase { public void testReadThrowsIOException() throws IOException { IOException ioException = new IOException(); - when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(ioException); IOException ex = expectThrows(IOException.class, () -> context.read()); assertSame(ioException, ex); } public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer[].class))).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.read()); @@ -143,22 +150,28 @@ public class BytesChannelContextTests extends ESTestCase { } public void testReadLessThanZeroMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer[].class))).thenReturn(-1); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L); assertEquals(0, context.read()); assertTrue(context.selectorShouldClose()); } + @SuppressWarnings("unchecked") public void testCloseClosesChannelBuffer() throws IOException { - when(channel.isOpen()).thenReturn(true); - Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - buffer.ensureCapacity(1); - BytesChannelContext context = new BytesChannelContext(channel, null, readConsumer, buffer); - context.closeFromSelector(); - verify(closer).run(); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); + + when(channel.isOpen()).thenReturn(true); + Runnable closer = mock(Runnable.class); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + buffer.ensureCapacity(1); + BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, buffer); + context.closeFromSelector(); + verify(closer).run(); + } } public void testWriteFailsIfClosing() { @@ -182,7 +195,7 @@ public class BytesChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -196,7 +209,7 @@ public class BytesChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -204,25 +217,31 @@ public class BytesChannelContextTests extends ESTestCase { assertFalse(context.hasQueuedWriteOps()); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); assertTrue(context.hasQueuedWriteOps()); } + @SuppressWarnings("unchecked") public void testWriteOpsClearedOnClose() throws Exception { - assertFalse(context.hasQueuedWriteOps()); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + assertFalse(context.hasQueuedWriteOps()); - assertTrue(context.hasQueuedWriteOps()); + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); + assertTrue(context.hasQueuedWriteOps()); - verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); - assertFalse(context.hasQueuedWriteOps()); + verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + + assertFalse(context.hasQueuedWriteOps()); + } } public void testQueuedWriteIsFlushedInFlushCall() throws Exception { @@ -239,7 +258,7 @@ public class BytesChannelContextTests extends ESTestCase { when(writeOperation.getListener()).thenReturn(listener); context.flushChannel(); - verify(channel).write(buffers); + verify(rawChannel).write(buffers, 0, buffers.length); verify(selector).executeListener(listener, null); assertFalse(context.hasQueuedWriteOps()); } @@ -253,6 +272,7 @@ public class BytesChannelContextTests extends ESTestCase { assertTrue(context.hasQueuedWriteOps()); when(writeOperation.isFullyFlushed()).thenReturn(false); + when(writeOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); context.flushChannel(); verify(listener, times(0)).accept(null, null); @@ -266,6 +286,8 @@ public class BytesChannelContextTests extends ESTestCase { BiConsumer listener2 = mock(BiConsumer.class); BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class); BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class); + when(writeOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); + when(writeOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); when(writeOperation1.getListener()).thenReturn(listener); when(writeOperation2.getListener()).thenReturn(listener2); context.queueWriteOperation(writeOperation1); @@ -300,7 +322,7 @@ public class BytesChannelContextTests extends ESTestCase { IOException exception = new IOException(); when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(channel.write(buffers)).thenThrow(exception); + when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); when(writeOperation.getListener()).thenReturn(listener); expectThrows(IOException.class, () -> context.flushChannel()); @@ -315,7 +337,7 @@ public class BytesChannelContextTests extends ESTestCase { IOException exception = new IOException(); when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(channel.write(buffers)).thenThrow(exception); + when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.flushChannel()); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java similarity index 91% rename from libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java rename to libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java index 59fb9cde438..05afc80a490 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java @@ -25,29 +25,26 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.List; import java.util.function.BiConsumer; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -public class WriteOperationTests extends ESTestCase { +public class BytesWriteOperationTests extends ESTestCase { - private NioSocketChannel channel; + private SocketChannelContext channelContext; private BiConsumer listener; @Before @SuppressWarnings("unchecked") public void setFields() { - channel = mock(NioSocketChannel.class); + channelContext = mock(SocketChannelContext.class); listener = mock(BiConsumer.class); } public void testFullyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); writeOp.incrementIndex(10); @@ -56,7 +53,7 @@ public class WriteOperationTests extends ESTestCase { public void testPartiallyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); writeOp.incrementIndex(5); @@ -65,7 +62,7 @@ public class WriteOperationTests extends ESTestCase { public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); ArgumentCaptor buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java new file mode 100644 index 00000000000..f262dd06330 --- /dev/null +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelContextTests.java @@ -0,0 +1,214 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class ChannelContextTests extends ESTestCase { + + private TestChannelContext context; + private Consumer exceptionHandler; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + super.setUp(); + exceptionHandler = mock(Consumer.class); + } + + public void testCloseSuccess() throws IOException { + FakeRawChannel rawChannel = new FakeRawChannel(null); + context = new TestChannelContext(rawChannel, exceptionHandler); + + AtomicBoolean listenerCalled = new AtomicBoolean(false); + context.addCloseListener((v, t) -> { + if (t == null) { + listenerCalled.compareAndSet(false, true); + } else { + throw new AssertionError("Close should not fail"); + } + }); + + assertFalse(rawChannel.hasCloseBeenCalled()); + assertTrue(context.isOpen()); + assertFalse(listenerCalled.get()); + context.closeFromSelector(); + assertTrue(rawChannel.hasCloseBeenCalled()); + assertFalse(context.isOpen()); + assertTrue(listenerCalled.get()); + } + + public void testCloseException() throws IOException { + IOException ioException = new IOException("boom"); + FakeRawChannel rawChannel = new FakeRawChannel(ioException); + context = new TestChannelContext(rawChannel, exceptionHandler); + + AtomicReference exception = new AtomicReference<>(); + context.addCloseListener((v, t) -> { + if (t == null) { + throw new AssertionError("Close should not fail"); + } else { + exception.set((Exception) t); + } + }); + + assertFalse(rawChannel.hasCloseBeenCalled()); + assertTrue(context.isOpen()); + assertNull(exception.get()); + expectThrows(IOException.class, context::closeFromSelector); + assertTrue(rawChannel.hasCloseBeenCalled()); + assertFalse(context.isOpen()); + assertSame(ioException, exception.get()); + } + + public void testExceptionsAreDelegatedToHandler() { + context = new TestChannelContext(new FakeRawChannel(null), exceptionHandler); + IOException exception = new IOException(); + context.handleException(exception); + verify(exceptionHandler).accept(exception); + } + + private static class TestChannelContext extends ChannelContext { + + private TestChannelContext(FakeRawChannel channel, Consumer exceptionHandler) { + super(channel, exceptionHandler); + } + + @Override + public void closeChannel() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public ESSelector getSelector() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public NioChannel getChannel() { + throw new UnsupportedOperationException("not implemented"); + } + } + + private class FakeRawChannel extends SelectableChannel implements NetworkChannel { + + private final IOException exceptionOnClose; + private AtomicBoolean hasCloseBeenCalled = new AtomicBoolean(false); + + private FakeRawChannel(IOException exceptionOnClose) { + this.exceptionOnClose = exceptionOnClose; + } + + @Override + protected void implCloseChannel() throws IOException { + hasCloseBeenCalled.compareAndSet(false, true); + if (exceptionOnClose != null) { + throw exceptionOnClose; + } + } + + private boolean hasCloseBeenCalled() { + return hasCloseBeenCalled.get(); + } + + @Override + public NetworkChannel bind(SocketAddress local) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SocketAddress getLocalAddress() throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public NetworkChannel setOption(SocketOption name, T value) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public T getOption(SocketOption name) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public Set> supportedOptions() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectorProvider provider() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public int validOps() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public boolean isRegistered() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectionKey keyFor(Selector sel) { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectionKey register(Selector sel, int ops, Object att) throws ClosedChannelException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public SelectableChannel configureBlocking(boolean block) throws IOException { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public boolean isBlocking() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public Object blockingLock() { + throw new UnsupportedOperationException("not implemented"); + } + } +} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java index 1c8a8a130cc..99880a2fd80 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; +import java.util.function.Supplier; import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; @@ -43,6 +44,8 @@ public class ChannelFactoryTests extends ESTestCase { private SocketChannel rawChannel; private ServerSocketChannel rawServerChannel; private SocketSelector socketSelector; + private Supplier socketSelectorSupplier; + private Supplier acceptingSelectorSupplier; private AcceptingSelector acceptingSelector; @Before @@ -52,34 +55,36 @@ public class ChannelFactoryTests extends ESTestCase { channelFactory = new TestChannelFactory(rawChannelFactory); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); + socketSelectorSupplier = mock(Supplier.class); + acceptingSelectorSupplier = mock(Supplier.class); rawChannel = SocketChannel.open(); rawServerChannel = ServerSocketChannel.open(); + when(socketSelectorSupplier.get()).thenReturn(socketSelector); + when(acceptingSelectorSupplier.get()).thenReturn(acceptingSelector); } @After public void ensureClosed() throws IOException { - IOUtils.closeWhileHandlingException(rawChannel); - IOUtils.closeWhileHandlingException(rawServerChannel); + IOUtils.closeWhileHandlingException(rawChannel, rawServerChannel); } public void testAcceptChannel() throws IOException { - NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); - when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); - NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector); + NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier); verify(socketSelector).scheduleForRegistration(channel); - assertEquals(socketSelector, channel.getSelector()); assertEquals(rawChannel, channel.getRawChannel()); } public void testAcceptedChannelRejected() throws IOException { - NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class); - when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); + ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); + when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannel, socketSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier)); assertFalse(rawChannel.isOpen()); } @@ -88,11 +93,10 @@ public class ChannelFactoryTests extends ESTestCase { InetSocketAddress address = mock(InetSocketAddress.class); when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); - NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelector); + NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier); verify(socketSelector).scheduleForRegistration(channel); - assertEquals(socketSelector, channel.getSelector()); assertEquals(rawChannel, channel.getRawChannel()); } @@ -101,7 +105,7 @@ public class ChannelFactoryTests extends ESTestCase { when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier)); assertFalse(rawChannel.isOpen()); } @@ -110,11 +114,10 @@ public class ChannelFactoryTests extends ESTestCase { InetSocketAddress address = mock(InetSocketAddress.class); when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); - NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelector); + NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier); verify(acceptingSelector).scheduleForRegistration(channel); - assertEquals(acceptingSelector, channel.getSelector()); assertEquals(rawServerChannel, channel.getRawChannel()); } @@ -123,7 +126,7 @@ public class ChannelFactoryTests extends ESTestCase { when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelector)); + expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier)); assertFalse(rawServerChannel.isOpen()); } @@ -137,14 +140,14 @@ public class ChannelFactoryTests extends ESTestCase { @SuppressWarnings("unchecked") @Override public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { - NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); + NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); nioSocketChannel.setContext(mock(SocketChannelContext.class)); return nioSocketChannel; } @Override public NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - return new NioServerSocketChannel(channel, this, selector); + return new NioServerSocketChannel(channel); } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java deleted file mode 100644 index dd73d43292a..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterChannel.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.SocketChannel; - -public class DoNotRegisterChannel extends NioSocketChannel { - - public DoNotRegisterChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); - } - - @Override - public void register() throws ClosedChannelException { - setSelectionKey(new TestSelectionKey(0)); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java deleted file mode 100644 index 1d5e605c444..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/DoNotRegisterServerChannel.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.ServerSocketChannel; - -public class DoNotRegisterServerChannel extends NioServerSocketChannel { - - public DoNotRegisterServerChannel(ServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector) - throws IOException { - super(channel, channelFactory, selector); - } - - @Override - public void register() throws ClosedChannelException { - setSelectionKey(new TestSelectionKey(0)); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java index 1dab05487a1..cb8f0757fb9 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ESSelectorTests.java @@ -27,6 +27,7 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedSelectorException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.mock; @@ -47,15 +48,18 @@ public class ESSelectorTests extends ESTestCase { selector = new TestSelector(handler, rawSelector); } + @SuppressWarnings({"unchecked", "rawtypes"}) public void testQueueChannelForClosed() throws IOException { NioChannel channel = mock(NioChannel.class); - when(channel.getSelector()).thenReturn(selector); + ChannelContext context = mock(ChannelContext.class); + when(channel.getContext()).thenReturn(context); + when(context.getSelector()).thenReturn(selector); selector.queueChannelClose(channel); selector.singleLoop(); - verify(handler).handleClose(channel); + verify(handler).handleClose(context); } public void testSelectorClosedExceptionIsNotCaughtWhileRunning() throws IOException { diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java deleted file mode 100644 index 12a77a425eb..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.FutureUtils; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.nio.channels.ServerSocketChannel; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import static org.mockito.Mockito.mock; - -public class NioServerSocketChannelTests extends ESTestCase { - - private AcceptingSelector selector; - private AtomicBoolean closedRawChannel; - private Thread thread; - - @Before - @SuppressWarnings("unchecked") - public void setSelector() throws IOException { - selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(Supplier.class))); - thread = new Thread(selector::runLoop); - closedRawChannel = new AtomicBoolean(false); - thread.start(); - FutureUtils.get(selector.isRunningFuture()); - } - - @After - public void stopSelector() throws IOException, InterruptedException { - selector.close(); - thread.join(); - } - - @SuppressWarnings("unchecked") - public void testClose() throws Exception { - AtomicBoolean isClosed = new AtomicBoolean(false); - CountDownLatch latch = new CountDownLatch(1); - - try (ServerSocketChannel rawChannel = ServerSocketChannel.open()) { - NioServerSocketChannel channel = new NioServerSocketChannel(rawChannel, mock(ChannelFactory.class), selector); - channel.setContext(new ServerChannelContext(channel, mock(Consumer.class), mock(BiConsumer.class))); - channel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { - @Override - public void onResponse(Void o) { - isClosed.set(true); - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - isClosed.set(true); - latch.countDown(); - } - })); - - assertTrue(channel.isOpen()); - assertTrue(rawChannel.isOpen()); - assertFalse(isClosed.get()); - - PlainActionFuture closeFuture = PlainActionFuture.newFuture(); - channel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - selector.queueChannelClose(channel); - closeFuture.actionGet(); - - - assertFalse(rawChannel.isOpen()); - assertFalse(channel.isOpen()); - latch.await(); - assertTrue(isClosed.get()); - } - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java deleted file mode 100644 index bbda9233bbb..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.FutureUtils; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.net.ConnectException; -import java.nio.channels.SocketChannel; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class NioSocketChannelTests extends ESTestCase { - - private SocketSelector selector; - private Thread thread; - - @Before - @SuppressWarnings("unchecked") - public void startSelector() throws IOException { - selector = new SocketSelector(new SocketEventHandler(logger)); - thread = new Thread(selector::runLoop); - thread.start(); - FutureUtils.get(selector.isRunningFuture()); - } - - @After - public void stopSelector() throws IOException, InterruptedException { - selector.close(); - thread.join(); - } - - @SuppressWarnings("unchecked") - public void testClose() throws Exception { - AtomicBoolean isClosed = new AtomicBoolean(false); - CountDownLatch latch = new CountDownLatch(1); - - try(SocketChannel rawChannel = SocketChannel.open()) { - NioSocketChannel socketChannel = new NioSocketChannel(rawChannel, selector); - socketChannel.setContext(new BytesChannelContext(socketChannel, mock(BiConsumer.class), - mock(SocketChannelContext.ReadConsumer.class), InboundChannelBuffer.allocatingInstance())); - socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { - @Override - public void onResponse(Void o) { - isClosed.set(true); - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - isClosed.set(true); - latch.countDown(); - } - })); - - assertTrue(socketChannel.isOpen()); - assertTrue(rawChannel.isOpen()); - assertFalse(isClosed.get()); - - PlainActionFuture closeFuture = PlainActionFuture.newFuture(); - socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - selector.queueChannelClose(socketChannel); - closeFuture.actionGet(); - - assertFalse(rawChannel.isOpen()); - assertFalse(socketChannel.isOpen()); - latch.await(); - assertTrue(isClosed.get()); - } - } - - @SuppressWarnings("unchecked") - public void testConnectSucceeds() throws Exception { - SocketChannel rawChannel = mock(SocketChannel.class); - when(rawChannel.finishConnect()).thenReturn(true); - NioSocketChannel socketChannel = new DoNotRegisterChannel(rawChannel, selector); - socketChannel.setContext(mock(SocketChannelContext.class)); - selector.scheduleForRegistration(socketChannel); - - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - socketChannel.addConnectListener(ActionListener.toBiConsumer(connectFuture)); - connectFuture.get(100, TimeUnit.SECONDS); - - assertTrue(socketChannel.isConnectComplete()); - assertTrue(socketChannel.isOpen()); - } - - @SuppressWarnings("unchecked") - public void testConnectFails() throws Exception { - SocketChannel rawChannel = mock(SocketChannel.class); - when(rawChannel.finishConnect()).thenThrow(new ConnectException()); - NioSocketChannel socketChannel = new DoNotRegisterChannel(rawChannel, selector); - socketChannel.setContext(mock(SocketChannelContext.class)); - selector.scheduleForRegistration(socketChannel); - - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - socketChannel.addConnectListener(ActionListener.toBiConsumer(connectFuture)); - ExecutionException e = expectThrows(ExecutionException.class, () -> connectFuture.get(100, TimeUnit.SECONDS)); - assertTrue(e.getCause() instanceof IOException); - - assertFalse(socketChannel.isConnectComplete()); - // Even if connection fails the channel is 'open' until close() is called - assertTrue(socketChannel.isOpen()); - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java new file mode 100644 index 00000000000..17e6b7acba2 --- /dev/null +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -0,0 +1,173 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SocketChannelContextTests extends ESTestCase { + + private SocketChannel rawChannel; + private TestSocketChannelContext context; + private Consumer exceptionHandler; + private NioSocketChannel channel; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + super.setUp(); + + rawChannel = mock(SocketChannel.class); + channel = mock(NioSocketChannel.class); + when(channel.getRawChannel()).thenReturn(rawChannel); + exceptionHandler = mock(Consumer.class); + context = new TestSocketChannelContext(channel, mock(SocketSelector.class), exceptionHandler); + } + + public void testIOExceptionSetIfEncountered() throws IOException { + when(rawChannel.write(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); + when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException()); + assertFalse(context.hasIOException()); + expectThrows(IOException.class, () -> { + if (randomBoolean()) { + context.read(); + } else { + context.flushChannel(); + } + }); + assertTrue(context.hasIOException()); + } + + public void testSignalWhenPeerClosed() throws IOException { + when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L); + when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1); + assertFalse(context.isPeerClosed()); + context.read(); + assertTrue(context.isPeerClosed()); + } + + public void testConnectSucceeds() throws IOException { + AtomicBoolean listenerCalled = new AtomicBoolean(false); + when(rawChannel.finishConnect()).thenReturn(false, true); + + context.addConnectListener((v, t) -> { + if (t == null) { + listenerCalled.compareAndSet(false, true); + } else { + throw new AssertionError("Connection should not fail"); + } + }); + + assertFalse(context.connect()); + assertFalse(context.isConnectComplete()); + assertFalse(listenerCalled.get()); + assertTrue(context.connect()); + assertTrue(context.isConnectComplete()); + assertTrue(listenerCalled.get()); + } + + public void testConnectFails() throws IOException { + AtomicReference exception = new AtomicReference<>(); + IOException ioException = new IOException("boom"); + when(rawChannel.finishConnect()).thenReturn(false).thenThrow(ioException); + + context.addConnectListener((v, t) -> { + if (t == null) { + throw new AssertionError("Connection should not succeed"); + } else { + exception.set((Exception) t); + } + }); + + assertFalse(context.connect()); + assertFalse(context.isConnectComplete()); + assertNull(exception.get()); + expectThrows(IOException.class, context::connect); + assertFalse(context.isConnectComplete()); + assertSame(ioException, exception.get()); + } + + private static class TestSocketChannelContext extends SocketChannelContext { + + private TestSocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { + super(channel, selector, exceptionHandler); + } + + @Override + public int read() throws IOException { + if (randomBoolean()) { + ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; + return readFromChannel(byteBuffers); + } else { + return readFromChannel(ByteBuffer.allocate(10)); + } + } + + @Override + public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { + + } + + @Override + public void queueWriteOperation(WriteOperation writeOperation) { + + } + + @Override + public void flushChannel() throws IOException { + if (randomBoolean()) { + ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; + flushToChannel(byteBuffers); + } else { + flushToChannel(ByteBuffer.allocate(10)); + } + } + + @Override + public boolean hasQueuedWriteOps() { + return false; + } + + @Override + public boolean selectorShouldClose() { + return false; + } + + @Override + public void closeChannel() { + + } + } +} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java index d74214636db..4f476c1ff6b 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java @@ -23,12 +23,10 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.function.BiConsumer; -import java.util.function.Supplier; +import java.util.function.Consumer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -37,137 +35,171 @@ import static org.mockito.Mockito.when; public class SocketEventHandlerTests extends ESTestCase { - private BiConsumer exceptionHandler; + private Consumer exceptionHandler; private SocketEventHandler handler; private NioSocketChannel channel; private SocketChannel rawChannel; + private DoNotRegisterContext context; @Before @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { - exceptionHandler = mock(BiConsumer.class); - SocketSelector socketSelector = mock(SocketSelector.class); + exceptionHandler = mock(Consumer.class); + SocketSelector selector = mock(SocketSelector.class); handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); - channel = new DoNotRegisterChannel(rawChannel, socketSelector); + channel = new NioSocketChannel(rawChannel); when(rawChannel.finishConnect()).thenReturn(true); - InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); - channel.setContext(new BytesChannelContext(channel, exceptionHandler, mock(SocketChannelContext.ReadConsumer.class), buffer)); - channel.register(); - channel.finishConnect(); + context = new DoNotRegisterContext(channel, selector, exceptionHandler, new TestSelectionKey(0)); + channel.setContext(context); + handler.handleRegistration(context); - when(socketSelector.isOnCurrentThread()).thenReturn(true); + when(selector.isOnCurrentThread()).thenReturn(true); } public void testRegisterCallsContext() throws IOException { NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext channelContext = mock(SocketChannelContext.class); when(channel.getContext()).thenReturn(channelContext); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - handler.handleRegistration(channel); - verify(channelContext).channelRegistered(); + when(channelContext.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(channelContext); + verify(channelContext).register(); } public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException { - handler.handleRegistration(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps()); + SocketChannelContext context = mock(SocketChannelContext.class); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, context.getSelectionKey().interestOps()); + } + + public void testRegisterAddsAttachment() throws IOException { + SocketChannelContext context = mock(SocketChannelContext.class); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(context); + assertEquals(context, context.getSelectionKey().attachment()); } public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class)); - handler.handleRegistration(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + handler.handleRegistration(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps()); } public void testRegistrationExceptionCallsExceptionHandler() throws IOException { CancelledKeyException exception = new CancelledKeyException(); - handler.registrationException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.registrationException(context, exception); + verify(exceptionHandler).accept(exception); } - public void testConnectRemovesOP_CONNECTInterest() throws IOException { - SelectionKeyUtils.setConnectAndReadInterested(channel); - handler.handleConnect(channel); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + public void testConnectDoesNotRemoveOP_CONNECTInterestIfIncomplete() throws IOException { + SelectionKeyUtils.setConnectAndReadInterested(context.getSelectionKey()); + handler.handleConnect(context); + assertEquals(SelectionKey.OP_READ, context.getSelectionKey().interestOps()); + } + + public void testConnectRemovesOP_CONNECTInterestIfComplete() throws IOException { + SelectionKeyUtils.setConnectAndReadInterested(context.getSelectionKey()); + handler.handleConnect(context); + assertEquals(SelectionKey.OP_READ, context.getSelectionKey().interestOps()); } public void testConnectExceptionCallsExceptionHandler() throws IOException { IOException exception = new IOException(); - handler.connectException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.connectException(context, exception); + verify(exceptionHandler).accept(exception); } public void testHandleReadDelegatesToContext() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); + NioSocketChannel channel = new NioSocketChannel(rawChannel); SocketChannelContext context = mock(SocketChannelContext.class); channel.setContext(context); when(context.read()).thenReturn(1); - handler.handleRead(channel); + handler.handleRead(context); verify(context).read(); } public void testReadExceptionCallsExceptionHandler() { IOException exception = new IOException(); - handler.readException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.readException(context, exception); + verify(exceptionHandler).accept(exception); } public void testWriteExceptionCallsExceptionHandler() { IOException exception = new IOException(); - handler.writeException(channel, exception); - verify(exceptionHandler).accept(channel, exception); + handler.writeException(context, exception); + verify(exceptionHandler).accept(exception); } public void testPostHandlingCallWillCloseTheChannelIfReady() throws IOException { NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext context = mock(SocketChannelContext.class); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); when(channel.getContext()).thenReturn(context); when(context.selectorShouldClose()).thenReturn(true); - handler.postHandling(channel); + handler.postHandling(context); verify(context).closeFromSelector(); } public void testPostHandlingCallWillNotCloseTheChannelIfNotReady() throws IOException { - NioSocketChannel channel = mock(NioSocketChannel.class); SocketChannelContext context = mock(SocketChannelContext.class); - when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - - when(channel.getContext()).thenReturn(context); + when(context.getSelectionKey()).thenReturn(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE)); when(context.selectorShouldClose()).thenReturn(false); - handler.postHandling(channel); - verify(channel, times(0)).closeFromSelector(); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + handler.postHandling(context); + + verify(context, times(0)).closeFromSelector(); } public void testPostHandlingWillAddWriteIfNecessary() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); - channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ)); + TestSelectionKey selectionKey = new TestSelectionKey(SelectionKey.OP_READ); SocketChannelContext context = mock(SocketChannelContext.class); - channel.setContext(context); - + when(context.getSelectionKey()).thenReturn(selectionKey); when(context.hasQueuedWriteOps()).thenReturn(true); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); - handler.postHandling(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + handler.postHandling(context); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); } public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException { - NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); - channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE)); + TestSelectionKey key = new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE); SocketChannelContext context = mock(SocketChannelContext.class); - channel.setContext(context); - + when(context.getSelectionKey()).thenReturn(key); when(context.hasQueuedWriteOps()).thenReturn(false); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); - handler.postHandling(channel); - assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + NioSocketChannel channel = mock(NioSocketChannel.class); + when(channel.getContext()).thenReturn(context); + + + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, key.interestOps()); + handler.postHandling(context); + assertEquals(SelectionKey.OP_READ, key.interestOps()); + } + + private class DoNotRegisterContext extends BytesChannelContext { + + private final TestSelectionKey selectionKey; + + DoNotRegisterContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, + TestSelectionKey selectionKey) { + super(channel, selector, exceptionHandler, mock(ReadConsumer.class), InboundChannelBuffer.allocatingInstance()); + this.selectionKey = selectionKey; + } + + @Override + public void register() { + setSelectionKey(selectionKey); + } } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java index 5992244b2f9..223f14455f9 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java @@ -58,22 +58,22 @@ public class SocketSelectorTests extends ESTestCase { @SuppressWarnings("unchecked") public void setUp() throws Exception { super.setUp(); + rawSelector = mock(Selector.class); eventHandler = mock(SocketEventHandler.class); channel = mock(NioSocketChannel.class); channelContext = mock(SocketChannelContext.class); listener = mock(BiConsumer.class); selectionKey = new TestSelectionKey(0); - selectionKey.attach(channel); - rawSelector = mock(Selector.class); + selectionKey.attach(channelContext); this.socketSelector = new SocketSelector(eventHandler, rawSelector); this.socketSelector.setThread(); - when(channel.isOpen()).thenReturn(true); - when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getContext()).thenReturn(channelContext); - when(channel.isConnectComplete()).thenReturn(true); - when(channel.getSelector()).thenReturn(socketSelector); + when(channelContext.isOpen()).thenReturn(true); + when(channelContext.getSelector()).thenReturn(socketSelector); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.isConnectComplete()).thenReturn(true); } public void testRegisterChannel() throws Exception { @@ -81,64 +81,52 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.preSelect(); - verify(eventHandler).handleRegistration(channel); + verify(eventHandler).handleRegistration(channelContext); } public void testClosedChannelWillNotBeRegistered() throws Exception { - when(channel.isOpen()).thenReturn(false); + when(channelContext.isOpen()).thenReturn(false); socketSelector.scheduleForRegistration(channel); socketSelector.preSelect(); - verify(eventHandler).registrationException(same(channel), any(ClosedChannelException.class)); - verify(channel, times(0)).finishConnect(); + verify(eventHandler).registrationException(same(channelContext), any(ClosedChannelException.class)); + verify(eventHandler, times(0)).handleConnect(channelContext); } public void testRegisterChannelFailsDueToException() throws Exception { socketSelector.scheduleForRegistration(channel); ClosedChannelException closedChannelException = new ClosedChannelException(); - doThrow(closedChannelException).when(channel).register(); + doThrow(closedChannelException).when(eventHandler).handleRegistration(channelContext); socketSelector.preSelect(); - verify(eventHandler).registrationException(channel, closedChannelException); - verify(channel, times(0)).finishConnect(); + verify(eventHandler).registrationException(channelContext, closedChannelException); + verify(eventHandler, times(0)).handleConnect(channelContext); } - public void testSuccessfullyRegisterChannelWillConnect() throws Exception { + public void testSuccessfullyRegisterChannelWillAttemptConnect() throws Exception { socketSelector.scheduleForRegistration(channel); - when(channel.finishConnect()).thenReturn(true); - socketSelector.preSelect(); - verify(eventHandler).handleConnect(channel); - } - - public void testConnectIncompleteWillNotNotify() throws Exception { - socketSelector.scheduleForRegistration(channel); - - when(channel.finishConnect()).thenReturn(false); - - socketSelector.preSelect(); - - verify(eventHandler, times(0)).handleConnect(channel); + verify(eventHandler).handleConnect(channelContext); } public void testQueueWriteWhenNotRunning() throws Exception { socketSelector.close(); - socketSelector.queueWrite(new BytesWriteOperation(channel, buffers, listener)); + socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); } public void testQueueWriteChannelIsClosed() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); - when(channel.isOpen()).thenReturn(false); + when(channelContext.isOpen()).thenReturn(false); socketSelector.preSelect(); verify(channelContext, times(0)).queueWriteOperation(writeOperation); @@ -148,11 +136,11 @@ public class SocketSelectorTests extends ESTestCase { public void testQueueWriteSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); socketSelector.queueWrite(writeOperation); - when(channel.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.preSelect(); @@ -161,7 +149,7 @@ public class SocketSelectorTests extends ESTestCase { } public void testQueueWriteSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -173,7 +161,7 @@ public class SocketSelectorTests extends ESTestCase { } public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -186,10 +174,10 @@ public class SocketSelectorTests extends ESTestCase { public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); - when(channel.getSelectionKey()).thenReturn(selectionKey); + when(channelContext.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.queueWriteInChannelBuffer(writeOperation); @@ -200,19 +188,9 @@ public class SocketSelectorTests extends ESTestCase { public void testConnectEvent() throws Exception { selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - when(channel.finishConnect()).thenReturn(true); socketSelector.processKey(selectionKey); - verify(eventHandler).handleConnect(channel); - } - - public void testConnectEventFinishUnsuccessful() throws Exception { - selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - - when(channel.finishConnect()).thenReturn(false); - socketSelector.processKey(selectionKey); - - verify(eventHandler, times(0)).handleConnect(channel); + verify(eventHandler).handleConnect(channelContext); } public void testConnectEventFinishThrowException() throws Exception { @@ -220,11 +198,10 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_CONNECT); - when(channel.finishConnect()).thenThrow(ioException); + doThrow(ioException).when(eventHandler).handleConnect(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler, times(0)).handleConnect(channel); - verify(eventHandler).connectException(channel, ioException); + verify(eventHandler).connectException(channelContext, ioException); } public void testWillNotConsiderWriteOrReadUntilConnectionComplete() throws Exception { @@ -232,13 +209,13 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ); - doThrow(ioException).when(eventHandler).handleWrite(channel); + doThrow(ioException).when(eventHandler).handleWrite(channelContext); - when(channel.isConnectComplete()).thenReturn(false); + when(channelContext.isConnectComplete()).thenReturn(false); socketSelector.processKey(selectionKey); - verify(eventHandler, times(0)).handleWrite(channel); - verify(eventHandler, times(0)).handleRead(channel); + verify(eventHandler, times(0)).handleWrite(channelContext); + verify(eventHandler, times(0)).handleRead(channelContext); } public void testSuccessfulWriteEvent() throws Exception { @@ -246,7 +223,7 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleWrite(channel); + verify(eventHandler).handleWrite(channelContext); } public void testWriteEventWithException() throws Exception { @@ -254,11 +231,11 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_WRITE); - doThrow(ioException).when(eventHandler).handleWrite(channel); + doThrow(ioException).when(eventHandler).handleWrite(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler).writeException(channel, ioException); + verify(eventHandler).writeException(channelContext, ioException); } public void testSuccessfulReadEvent() throws Exception { @@ -266,7 +243,7 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleRead(channel); + verify(eventHandler).handleRead(channelContext); } public void testReadEventWithException() throws Exception { @@ -274,11 +251,11 @@ public class SocketSelectorTests extends ESTestCase { selectionKey.setReadyOps(SelectionKey.OP_READ); - doThrow(ioException).when(eventHandler).handleRead(channel); + doThrow(ioException).when(eventHandler).handleRead(channelContext); socketSelector.processKey(selectionKey); - verify(eventHandler).readException(channel, ioException); + verify(eventHandler).readException(channelContext, ioException); } public void testWillCallPostHandleAfterChannelHandling() throws Exception { @@ -286,30 +263,32 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.processKey(selectionKey); - verify(eventHandler).handleWrite(channel); - verify(eventHandler).handleRead(channel); - verify(eventHandler).postHandling(channel); + verify(eventHandler).handleWrite(channelContext); + verify(eventHandler).handleRead(channelContext); + verify(eventHandler).postHandling(channelContext); } public void testCleanup() throws Exception { - NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); + NioSocketChannel unregisteredChannel = mock(NioSocketChannel.class); + SocketChannelContext unregisteredContext = mock(SocketChannelContext.class); + when(unregisteredChannel.getContext()).thenReturn(unregisteredContext); socketSelector.scheduleForRegistration(channel); socketSelector.preSelect(); - socketSelector.queueWrite(new BytesWriteOperation(mock(NioSocketChannel.class), buffers, listener)); - socketSelector.scheduleForRegistration(unRegisteredChannel); + socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); + socketSelector.scheduleForRegistration(unregisteredChannel); TestSelectionKey testSelectionKey = new TestSelectionKey(0); - testSelectionKey.attach(channel); + testSelectionKey.attach(channelContext); when(rawSelector.keys()).thenReturn(new HashSet<>(Collections.singletonList(testSelectionKey))); socketSelector.cleanupAndCloseChannels(); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); - verify(eventHandler).handleClose(channel); - verify(eventHandler).handleClose(unRegisteredChannel); + verify(eventHandler).handleClose(channelContext); + verify(eventHandler).handleClose(unregisteredContext); } public void testExecuteListenerWillHandleException() throws Exception { diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index acea1ca5d48..eb3d7f3d710 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -31,15 +31,15 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.indices.breaker.CircuitBreakerService; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptorEventHandler; +import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketEventHandler; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; @@ -53,7 +53,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentMap; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.elasticsearch.common.settings.Setting.intSetting; @@ -179,15 +179,15 @@ public class NioTransport extends TcpTransport { @Override public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { - TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector); + TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel); Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BiConsumer exceptionHandler = NioTransport.this::exceptionCaught; - BytesChannelContext context = new BytesChannelContext(nioChannel, exceptionHandler, nioReadConsumer, + Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; @@ -195,8 +195,9 @@ public class NioTransport extends TcpTransport { @Override public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); - ServerChannelContext context = new ServerChannelContext(nioChannel, NioTransport.this::acceptChannel, (c, e) -> {}); + TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, NioTransport.this::acceptChannel, + (e) -> {}); nioChannel.setContext(context); return nioChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java index 683ae146cfb..c63acc9f4de 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java @@ -38,10 +38,8 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements private final String profile; - public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, - ChannelFactory channelFactory, - AcceptingSelector selector) throws IOException { - super(socketChannel, channelFactory, selector); + public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel) throws IOException { + super(socketChannel); this.profile = profile; } @@ -62,7 +60,7 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements @Override public void close() { - getSelector().queueChannelClose(this); + getContext().closeChannel(); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java index c2064e53ca6..44ab17457e8 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java @@ -33,8 +33,8 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel private final String profile; - public TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + public TcpNioSocketChannel(String profile, SocketChannel socketChannel) throws IOException { + super(socketChannel); this.profile = profile; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index ec262261e54..5271ac6a148 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -164,8 +164,8 @@ public class MockNioTransport extends TcpTransport { }; SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BytesChannelContext context = new BytesChannelContext(nioChannel, MockNioTransport.this::exceptionCaught, nioReadConsumer, - new InboundChannelBuffer(pageSupplier)); + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), + nioReadConsumer, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; } @@ -173,7 +173,8 @@ public class MockNioTransport extends TcpTransport { @Override public MockServerChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel, this, selector); - ServerChannelContext context = new ServerChannelContext(nioServerChannel, MockNioTransport.this::acceptChannel, (c, e) -> {}); + ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, MockNioTransport.this::acceptChannel, + (e) -> {}); nioServerChannel.setContext(context); return nioServerChannel; } @@ -185,13 +186,13 @@ public class MockNioTransport extends TcpTransport { MockServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory, AcceptingSelector selector) throws IOException { - super(channel, channelFactory, selector); + super(channel); this.profile = profile; } @Override public void close() { - getSelector().queueChannelClose(this); + getContext().closeChannel(); } @Override @@ -226,7 +227,7 @@ public class MockNioTransport extends TcpTransport { private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, SocketSelector selector) throws IOException { - super(socketChannel, selector); + super(socketChannel); this.profile = profile; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java index ecc00c24f9c..2e2d8aa5ada 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -20,7 +20,7 @@ package org.elasticsearch.transport.nio; import org.apache.logging.log4j.Logger; -import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketEventHandler; import java.io.IOException; @@ -34,36 +34,21 @@ public class TestingSocketEventHandler extends SocketEventHandler { super(logger); } - private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); + private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); - public void handleConnect(NioSocketChannel channel) { - assert hasConnectedMap.contains(channel) == false : "handleConnect should only be called once per channel"; - hasConnectedMap.add(channel); - super.handleConnect(channel); + public void handleConnect(SocketChannelContext context) throws IOException { + assert hasConnectedMap.contains(context) == false : "handleConnect should only be called is a channel is not yet connected"; + super.handleConnect(context); + if (context.isConnectComplete()) { + hasConnectedMap.add(context); + } } - private Set hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>()); + private Set hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>()); - public void connectException(NioSocketChannel channel, Exception e) { - assert hasConnectExceptionMap.contains(channel) == false : "connectException should only called at maximum once per channel"; - hasConnectExceptionMap.add(channel); - super.connectException(channel, e); + public void connectException(SocketChannelContext context, Exception e) { + assert hasConnectExceptionMap.contains(context) == false : "connectException should only called at maximum once per channel"; + hasConnectExceptionMap.add(context); + super.connectException(context, e); } - - public void handleRead(NioSocketChannel channel) throws IOException { - super.handleRead(channel); - } - - public void readException(NioSocketChannel channel, Exception e) { - super.readException(channel, e); - } - - public void handleWrite(NioSocketChannel channel) throws IOException { - super.handleWrite(channel); - } - - public void writeException(NioSocketChannel channel, Exception e) { - super.writeException(channel, e); - } - }