diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java index 8285fef6d39..e3dcbad024c 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java @@ -26,7 +26,6 @@ import java.nio.channels.NetworkChannel; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; /** @@ -48,9 +47,6 @@ import java.util.function.BiConsumer; public abstract class AbstractNioChannel implements NioChannel { final S socketChannel; - // This indicates if the channel has been scheduled to be closed. Read the closeFuture to determine if - // the channel close process has completed. - final AtomicBoolean isClosing = new AtomicBoolean(false); private final InetSocketAddress localAddress; private final CompletableFuture closeContext = new CompletableFuture<>(); @@ -73,21 +69,6 @@ public abstract class AbstractNioChannel - * If the channel is open and the state can be transitioned to closed, the close operation will - * be scheduled with the event loop. - *

- * If the channel is already set to closed, it is assumed that it is already scheduled to be closed. - */ - @Override - public void close() { - if (isClosing.compareAndSet(false, true)) { - selector.queueChannelClose(this); - } - } - /** * Closes the channel synchronously. This method should only be called from the selector thread. *

@@ -95,8 +76,7 @@ public abstract class AbstractNioChannel queued = new LinkedList<>(); + private final AtomicBoolean isClosing = new AtomicBoolean(false); + private boolean peerClosed = false; + private boolean ioException = false; + + public BytesChannelContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) { + this.channel = channel; + this.readConsumer = readConsumer; + this.channelBuffer = channelBuffer; + } + + @Override + public void channelRegistered() throws IOException {} + + @Override + public int read() throws IOException { + if (channelBuffer.getRemaining() == 0) { + // Requiring one additional byte will ensure that a new page is allocated. + channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1); + } + + int bytesRead; + try { + bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex())); + } catch (IOException ex) { + ioException = true; + throw ex; + } + + if (bytesRead == -1) { + peerClosed = true; + return 0; + } + + channelBuffer.incrementIndex(bytesRead); + + int bytesConsumed = Integer.MAX_VALUE; + while (bytesConsumed > 0 && channelBuffer.getIndex() > 0) { + bytesConsumed = readConsumer.consumeReads(channelBuffer); + channelBuffer.release(bytesConsumed); + } + + return bytesRead; + } + + @Override + public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { + if (isClosing.get()) { + listener.accept(null, new ClosedChannelException()); + return; + } + + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); + SocketSelector selector = channel.getSelector(); + if (selector.isOnCurrentThread() == false) { + selector.queueWrite(writeOperation); + return; + } + + // TODO: Eval if we will allow writes from sendMessage + selector.queueWriteInChannelBuffer(writeOperation); + } + + @Override + public void queueWriteOperation(WriteOperation writeOperation) { + channel.getSelector().assertOnSelectorThread(); + queued.add((BytesWriteOperation) writeOperation); + } + + @Override + public void flushChannel() throws IOException { + channel.getSelector().assertOnSelectorThread(); + int ops = queued.size(); + if (ops == 1) { + singleFlush(queued.pop()); + } else if (ops > 1) { + multiFlush(); + } + } + + @Override + public boolean hasQueuedWriteOps() { + channel.getSelector().assertOnSelectorThread(); + return queued.isEmpty() == false; + } + + @Override + public void closeChannel() { + if (isClosing.compareAndSet(false, true)) { + channel.getSelector().queueChannelClose(channel); + } + } + + @Override + public boolean selectorShouldClose() { + return peerClosed || ioException || isClosing.get(); + } + + @Override + public void closeFromSelector() { + channel.getSelector().assertOnSelectorThread(); + // Set to true in order to reject new writes before queuing with selector + isClosing.set(true); + channelBuffer.close(); + for (BytesWriteOperation op : queued) { + channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + } + queued.clear(); + } + + private void singleFlush(BytesWriteOperation headOp) throws IOException { + try { + int written = channel.write(headOp.getBuffersToWrite()); + headOp.incrementIndex(written); + } catch (IOException e) { + channel.getSelector().executeFailedListener(headOp.getListener(), e); + ioException = true; + throw e; + } + + if (headOp.isFullyFlushed()) { + channel.getSelector().executeListener(headOp.getListener(), null); + } else { + queued.push(headOp); + } + } + + private void multiFlush() throws IOException { + boolean lastOpCompleted = true; + while (lastOpCompleted && queued.isEmpty() == false) { + BytesWriteOperation op = queued.pop(); + singleFlush(op); + lastOpCompleted = op.isFullyFlushed(); + } + } +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesReadContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesReadContext.java deleted file mode 100644 index eeda147be6c..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesReadContext.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; - -public class BytesReadContext implements ReadContext { - - private final NioSocketChannel channel; - private final ReadConsumer readConsumer; - private final InboundChannelBuffer channelBuffer; - - public BytesReadContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) { - this.channel = channel; - this.channelBuffer = channelBuffer; - this.readConsumer = readConsumer; - } - - @Override - public int read() throws IOException { - if (channelBuffer.getRemaining() == 0) { - // Requiring one additional byte will ensure that a new page is allocated. - channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1); - } - - int bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex())); - - if (bytesRead == -1) { - return bytesRead; - } - - channelBuffer.incrementIndex(bytesRead); - - int bytesConsumed = Integer.MAX_VALUE; - while (bytesConsumed > 0) { - bytesConsumed = readConsumer.consumeReads(channelBuffer); - channelBuffer.release(bytesConsumed); - } - - return bytesRead; - } - - @Override - public void close() { - channelBuffer.close(); - } -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteContext.java deleted file mode 100644 index c2816deef53..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteContext.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.util.LinkedList; -import java.util.function.BiConsumer; - -public class BytesWriteContext implements WriteContext { - - private final NioSocketChannel channel; - private final LinkedList queued = new LinkedList<>(); - - public BytesWriteContext(NioSocketChannel channel) { - this.channel = channel; - } - - @Override - public void sendMessage(Object message, BiConsumer listener) { - ByteBuffer[] buffers = (ByteBuffer[]) message; - if (channel.isWritable() == false) { - listener.accept(null, new ClosedChannelException()); - return; - } - - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); - SocketSelector selector = channel.getSelector(); - if (selector.isOnCurrentThread() == false) { - selector.queueWrite(writeOperation); - return; - } - - // TODO: Eval if we will allow writes from sendMessage - selector.queueWriteInChannelBuffer(writeOperation); - } - - @Override - public void queueWriteOperations(WriteOperation writeOperation) { - assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to queue writes"; - queued.add(writeOperation); - } - - @Override - public void flushChannel() throws IOException { - assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to flush writes"; - int ops = queued.size(); - if (ops == 1) { - singleFlush(queued.pop()); - } else if (ops > 1) { - multiFlush(); - } - } - - @Override - public boolean hasQueuedWriteOps() { - assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to access queued writes"; - return queued.isEmpty() == false; - } - - @Override - public void clearQueuedWriteOps(Exception e) { - assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to clear queued writes"; - for (WriteOperation op : queued) { - channel.getSelector().executeFailedListener(op.getListener(), e); - } - queued.clear(); - } - - private void singleFlush(WriteOperation headOp) throws IOException { - try { - headOp.flush(); - } catch (IOException e) { - channel.getSelector().executeFailedListener(headOp.getListener(), e); - throw e; - } - - if (headOp.isFullyFlushed()) { - channel.getSelector().executeListener(headOp.getListener(), null); - } else { - queued.push(headOp); - } - } - - private void multiFlush() throws IOException { - boolean lastOpCompleted = true; - while (lastOpCompleted && queued.isEmpty() == false) { - WriteOperation op = queued.pop(); - singleFlush(op); - lastOpCompleted = op.isFullyFlushed(); - } - } -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java new file mode 100644 index 00000000000..14e8cace66d --- /dev/null +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.function.BiConsumer; + +public class BytesWriteOperation implements WriteOperation { + + private final NioSocketChannel channel; + private final BiConsumer listener; + private final ByteBuffer[] buffers; + private final int[] offsets; + private final int length; + private int internalIndex; + + public BytesWriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer listener) { + this.channel = channel; + this.listener = listener; + this.buffers = buffers; + this.offsets = new int[buffers.length]; + int offset = 0; + for (int i = 0; i < buffers.length; i++) { + ByteBuffer buffer = buffers[i]; + offsets[i] = offset; + offset += buffer.remaining(); + } + length = offset; + } + + @Override + public BiConsumer getListener() { + return listener; + } + + @Override + public NioSocketChannel getChannel() { + return channel; + } + + public boolean isFullyFlushed() { + assert length >= internalIndex : "Should never have an index that is greater than the length [length=" + length + ", index=" + + internalIndex + "]"; + return internalIndex == length; + } + + public void incrementIndex(int delta) { + internalIndex += delta; + assert length >= internalIndex : "Should never increment index past length [length=" + length + ", post-increment index=" + + internalIndex + ", delta=" + delta + "]"; + } + + public ByteBuffer[] getBuffersToWrite() { + final int index = Arrays.binarySearch(offsets, internalIndex); + int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index; + + ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex]; + + ByteBuffer firstBuffer = buffers[offsetIndex].duplicate(); + firstBuffer.position(internalIndex - offsets[offsetIndex]); + postIndexBuffers[0] = firstBuffer; + int j = 1; + for (int i = (offsetIndex + 1); i < buffers.length; ++i) { + postIndexBuffers[j++] = buffers[i].duplicate(); + } + + return postIndexBuffers; + } + +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java new file mode 100644 index 00000000000..10afd53621d --- /dev/null +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java @@ -0,0 +1,81 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.function.BiConsumer; + +/** + * This context should implement the specific logic for a channel. When a channel receives a notification + * that it is ready to perform certain operations (read, write, etc) the {@link ChannelContext} will be + * called. This context will need to implement all protocol related logic. Additionally, if any special + * close behavior is required, it should be implemented in this context. + * + * The only methods of the context that should ever be called from a non-selector thread are + * {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}. + */ +public interface ChannelContext { + + void channelRegistered() throws IOException; + + int read() throws IOException; + + void sendMessage(ByteBuffer[] buffers, BiConsumer listener); + + void queueWriteOperation(WriteOperation writeOperation); + + void flushChannel() throws IOException; + + boolean hasQueuedWriteOps(); + + /** + * Schedules a channel to be closed by the selector event loop with which it is registered. + *

+ * If the channel is open and the state can be transitioned to closed, the close operation will + * be scheduled with the event loop. + *

+ * If the channel is already set to closed, it is assumed that it is already scheduled to be closed. + *

+ * Depending on the underlying protocol of the channel, a close operation might simply close the socket + * channel or may involve reading and writing messages. + */ + void closeChannel(); + + /** + * This method indicates if a selector should close this channel. + * + * @return a boolean indicating if the selector should close + */ + boolean selectorShouldClose(); + + /** + * This method cleans up any context resources that need to be released when a channel is closed. It + * should only be called by the selector thread. + * + * @throws IOException during channel / context close + */ + void closeFromSelector() throws IOException; + + @FunctionalInterface + interface ReadConsumer { + int consumeReads(InboundChannelBuffer channelBuffer) throws IOException; + } +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java index d90927af8b9..a9909587453 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java @@ -88,8 +88,7 @@ public abstract class ChannelFactory { private final CompletableFuture connectContext = new CompletableFuture<>(); private final SocketSelector socketSelector; private final AtomicBoolean contextsSet = new AtomicBoolean(false); - private WriteContext writeContext; - private ReadContext readContext; + private ChannelContext context; private BiConsumer exceptionContext; private Exception connectException; @@ -47,14 +48,21 @@ public class NioSocketChannel extends AbstractNioChannel { @Override public void closeFromSelector() throws IOException { - assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; - // Even if the channel has already been closed we will clear any pending write operations just in case - if (writeContext.hasQueuedWriteOps()) { - writeContext.clearQueuedWriteOps(new ClosedChannelException()); + getSelector().assertOnSelectorThread(); + if (isOpen()) { + ArrayList closingExceptions = new ArrayList<>(2); + try { + super.closeFromSelector(); + } catch (IOException e) { + closingExceptions.add(e); + } + try { + context.closeFromSelector(); + } catch (IOException e) { + closingExceptions.add(e); + } + ExceptionsHelper.rethrowAndSuppress(closingExceptions); } - readContext.close(); - - super.closeFromSelector(); } @Override @@ -62,6 +70,10 @@ public class NioSocketChannel extends AbstractNioChannel { return socketSelector; } + public int write(ByteBuffer buffer) throws IOException { + return socketChannel.write(buffer); + } + public int write(ByteBuffer[] buffers) throws IOException { if (buffers.length == 1) { return socketChannel.write(buffers[0]); @@ -82,33 +94,17 @@ public class NioSocketChannel extends AbstractNioChannel { } } - public int read(InboundChannelBuffer buffer) throws IOException { - int bytesRead = (int) socketChannel.read(buffer.sliceBuffersFrom(buffer.getIndex())); - - if (bytesRead == -1) { - return bytesRead; - } - - buffer.incrementIndex(bytesRead); - return bytesRead; - } - - public void setContexts(ReadContext readContext, WriteContext writeContext, BiConsumer exceptionContext) { + public void setContexts(ChannelContext context, BiConsumer exceptionContext) { if (contextsSet.compareAndSet(false, true)) { - this.readContext = readContext; - this.writeContext = writeContext; + this.context = context; this.exceptionContext = exceptionContext; } else { throw new IllegalStateException("Contexts on this channel were already set. They should only be once."); } } - public WriteContext getWriteContext() { - return writeContext; - } - - public ReadContext getReadContext() { - return readContext; + public ChannelContext getContext() { + return context; } public BiConsumer getExceptionContext() { @@ -123,14 +119,6 @@ public class NioSocketChannel extends AbstractNioChannel { return isConnectComplete0(); } - public boolean isWritable() { - return isClosing.get() == false; - } - - public boolean isReadable() { - return isClosing.get() == false; - } - /** * This method will attempt to complete the connection process for this channel. It should be called for * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadContext.java deleted file mode 100644 index d23ce56f57a..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadContext.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; - -public interface ReadContext extends AutoCloseable { - - int read() throws IOException; - - @Override - void close(); - - @FunctionalInterface - interface ReadConsumer { - int consumeReads(InboundChannelBuffer channelBuffer) throws IOException; - } -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java index b6272ce7135..be2dc6f3414 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java @@ -26,28 +26,81 @@ public final class SelectionKeyUtils { private SelectionKeyUtils() {} + /** + * Adds an interest in writes for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ public static void setWriteInterested(NioChannel channel) throws CancelledKeyException { SelectionKey selectionKey = channel.getSelectionKey(); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE); } + /** + * Removes an interest in writes for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException { SelectionKey selectionKey = channel.getSelectionKey(); selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE); } + /** + * Removes an interest in connects and reads for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException { SelectionKey selectionKey = channel.getSelectionKey(); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ); } + /** + * Removes an interest in connects, reads, and writes for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ + public static void setConnectReadAndWriteInterested(NioChannel channel) throws CancelledKeyException { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ | SelectionKey.OP_WRITE); + } + + /** + * Removes an interest in connects for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException { SelectionKey selectionKey = channel.getSelectionKey(); selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT); } - public static void setAcceptInterested(NioServerSocketChannel channel) { + /** + * Adds an interest in accepts for this channel while maintaining other interests. + * + * @param channel the channel + * @throws CancelledKeyException if the key was already cancelled + */ + public static void setAcceptInterested(NioServerSocketChannel channel) throws CancelledKeyException { SelectionKey selectionKey = channel.getSelectionKey(); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT); } + + + /** + * Checks for an interest in writes for this channel. + * + * @param channel the channel + * @return a boolean indicating if we are currently interested in writes for this channel + * @throws CancelledKeyException if the key was already cancelled + */ + public static boolean isWriteInterested(NioSocketChannel channel) throws CancelledKeyException { + return (channel.getSelectionKey().interestOps() & SelectionKey.OP_WRITE) != 0; + } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java index d3be18f3776..d5977cee851 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java @@ -43,8 +43,14 @@ public class SocketEventHandler extends EventHandler { * * @param channel that was registered */ - protected void handleRegistration(NioSocketChannel channel) { - SelectionKeyUtils.setConnectAndReadInterested(channel); + protected void handleRegistration(NioSocketChannel channel) throws IOException { + ChannelContext context = channel.getContext(); + context.channelRegistered(); + if (context.hasQueuedWriteOps()) { + SelectionKeyUtils.setConnectReadAndWriteInterested(channel); + } else { + SelectionKeyUtils.setConnectAndReadInterested(channel); + } } /** @@ -86,10 +92,7 @@ public class SocketEventHandler extends EventHandler { * @param channel that can be read */ protected void handleRead(NioSocketChannel channel) throws IOException { - int bytesRead = channel.getReadContext().read(); - if (bytesRead == -1) { - handleClose(channel); - } + channel.getContext().read(); } /** @@ -107,16 +110,11 @@ public class SocketEventHandler extends EventHandler { * This method is called when a channel signals it is ready to receive writes. All of the write logic * should occur in this call. * - * @param channel that can be read + * @param channel that can be written to */ protected void handleWrite(NioSocketChannel channel) throws IOException { - WriteContext channelContext = channel.getWriteContext(); + ChannelContext channelContext = channel.getContext(); channelContext.flushChannel(); - if (channelContext.hasQueuedWriteOps()) { - SelectionKeyUtils.setWriteInterested(channel); - } else { - SelectionKeyUtils.removeWriteInterested(channel); - } } /** @@ -153,6 +151,23 @@ public class SocketEventHandler extends EventHandler { logger.warn(new ParameterizedMessage("exception while executing listener: {}", listener), exception); } + /** + * @param channel that was handled + */ + protected void postHandling(NioSocketChannel channel) { + if (channel.getContext().selectorShouldClose()) { + handleClose(channel); + } else { + boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(channel); + boolean pendingWrites = channel.getContext().hasQueuedWriteOps(); + if (currentlyWriteInterested == false && pendingWrites) { + SelectionKeyUtils.setWriteInterested(channel); + } else if (currentlyWriteInterested && pendingWrites == false) { + SelectionKeyUtils.removeWriteInterested(channel); + } + } + } + private void exceptionCaught(NioSocketChannel channel, Exception e) { channel.getExceptionContext().accept(channel, e); } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java index ac8ad87b726..e35aa7b4d22 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java @@ -64,6 +64,8 @@ public class SocketSelector extends ESSelector { handleRead(nioSocketChannel); } } + + eventHandler.postHandling(nioSocketChannel); } @Override @@ -118,12 +120,12 @@ public class SocketSelector extends ESSelector { * @param writeOperation to be queued in a channel's buffer */ public void queueWriteInChannelBuffer(WriteOperation writeOperation) { - assert isOnCurrentThread() : "Must be on selector thread"; + assertOnSelectorThread(); NioSocketChannel channel = writeOperation.getChannel(); - WriteContext context = channel.getWriteContext(); + ChannelContext context = channel.getContext(); try { SelectionKeyUtils.setWriteInterested(channel); - context.queueWriteOperations(writeOperation); + context.queueWriteOperation(writeOperation); } catch (Exception e) { executeFailedListener(writeOperation.getListener(), e); } @@ -137,7 +139,7 @@ public class SocketSelector extends ESSelector { * @param value to provide to listener */ public void executeListener(BiConsumer listener, V value) { - assert isOnCurrentThread() : "Must be on selector thread"; + assertOnSelectorThread(); try { listener.accept(value, null); } catch (Exception e) { @@ -153,7 +155,7 @@ public class SocketSelector extends ESSelector { * @param exception to provide to listener */ public void executeFailedListener(BiConsumer listener, Exception exception) { - assert isOnCurrentThread() : "Must be on selector thread"; + assertOnSelectorThread(); try { listener.accept(null, exception); } catch (Exception e) { @@ -180,7 +182,7 @@ public class SocketSelector extends ESSelector { private void handleQueuedWrites() { WriteOperation writeOperation; while ((writeOperation = queuedWrites.poll()) != null) { - if (writeOperation.getChannel().isWritable()) { + if (writeOperation.getChannel().isOpen()) { queueWriteInChannelBuffer(writeOperation); } else { executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteContext.java deleted file mode 100644 index 39e69e8f9a9..00000000000 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteContext.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import java.io.IOException; -import java.util.function.BiConsumer; - -public interface WriteContext { - - void sendMessage(Object message, BiConsumer listener); - - void queueWriteOperations(WriteOperation writeOperation); - - void flushChannel() throws IOException; - - boolean hasQueuedWriteOps(); - - void clearQueuedWriteOps(Exception e); - -} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java index b6fcc838a96..09800d981bd 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java @@ -19,74 +19,16 @@ package org.elasticsearch.nio; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; import java.util.function.BiConsumer; -public class WriteOperation { +/** + * This is a basic write operation that can be queued with a channel. The only requirements of a write + * operation is that is has a listener and a reference to its channel. The actual conversion of the write + * operation implementation to bytes will be performed by the {@link ChannelContext}. + */ +public interface WriteOperation { - private final NioSocketChannel channel; - private final BiConsumer listener; - private final ByteBuffer[] buffers; - private final int[] offsets; - private final int length; - private int internalIndex; + BiConsumer getListener(); - public WriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer listener) { - this.channel = channel; - this.listener = listener; - this.buffers = buffers; - this.offsets = new int[buffers.length]; - int offset = 0; - for (int i = 0; i < buffers.length; i++) { - ByteBuffer buffer = buffers[i]; - offsets[i] = offset; - offset += buffer.remaining(); - } - length = offset; - } - - public ByteBuffer[] getByteBuffers() { - return buffers; - } - - public BiConsumer getListener() { - return listener; - } - - public NioSocketChannel getChannel() { - return channel; - } - - public boolean isFullyFlushed() { - return internalIndex == length; - } - - public int flush() throws IOException { - int written = channel.write(getBuffersToWrite()); - internalIndex += written; - return written; - } - - private ByteBuffer[] getBuffersToWrite() { - int offsetIndex = getOffsetIndex(internalIndex); - - ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex]; - - ByteBuffer firstBuffer = buffers[offsetIndex].duplicate(); - firstBuffer.position(internalIndex - offsets[offsetIndex]); - postIndexBuffers[0] = firstBuffer; - int j = 1; - for (int i = (offsetIndex + 1); i < buffers.length; ++i) { - postIndexBuffers[j++] = buffers[i].duplicate(); - } - - return postIndexBuffers; - } - - private int getOffsetIndex(int offset) { - final int i = Arrays.binarySearch(offsets, offset); - return i < 0 ? (-(i + 1)) - 1 : i; - } + NioSocketChannel getChannel(); } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java index 9d8f47fe3ef..1f51fdc2017 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java @@ -80,7 +80,7 @@ public class AcceptorEventHandlerTests extends ESTestCase { @SuppressWarnings("unchecked") public void testHandleAcceptCallsServerAcceptCallback() throws IOException { NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); - childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + childChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class)); when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); handler.acceptChannel(channel); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java new file mode 100644 index 00000000000..db0e6ae80ba --- /dev/null +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -0,0 +1,337 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.isNull; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class BytesChannelContextTests extends ESTestCase { + + private ChannelContext.ReadConsumer readConsumer; + private NioSocketChannel channel; + private BytesChannelContext context; + private InboundChannelBuffer channelBuffer; + private SocketSelector selector; + private BiConsumer listener; + private int messageLength; + + @Before + @SuppressWarnings("unchecked") + public void init() { + readConsumer = mock(ChannelContext.ReadConsumer.class); + + messageLength = randomInt(96) + 20; + selector = mock(SocketSelector.class); + listener = mock(BiConsumer.class); + channel = mock(NioSocketChannel.class); + Supplier pageSupplier = () -> + new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {}); + channelBuffer = new InboundChannelBuffer(pageSupplier); + context = new BytesChannelContext(channel, readConsumer, channelBuffer); + + when(channel.getSelector()).thenReturn(selector); + when(selector.isOnCurrentThread()).thenReturn(true); + } + + public void testSuccessfulRead() throws IOException { + byte[] bytes = createMessage(messageLength); + + when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; + buffers[0].put(bytes); + return bytes.length; + }); + + when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0); + + assertEquals(messageLength, context.read()); + + assertEquals(0, channelBuffer.getIndex()); + assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); + verify(readConsumer, times(1)).consumeReads(channelBuffer); + } + + public void testMultipleReadsConsumed() throws IOException { + byte[] bytes = createMessage(messageLength * 2); + + when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; + buffers[0].put(bytes); + return bytes.length; + }); + + when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0); + + assertEquals(bytes.length, context.read()); + + assertEquals(0, channelBuffer.getIndex()); + assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); + verify(readConsumer, times(2)).consumeReads(channelBuffer); + } + + public void testPartialRead() throws IOException { + byte[] bytes = createMessage(messageLength); + + when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { + ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; + buffers[0].put(bytes); + return bytes.length; + }); + + + when(readConsumer.consumeReads(channelBuffer)).thenReturn(0); + + assertEquals(messageLength, context.read()); + + assertEquals(bytes.length, channelBuffer.getIndex()); + verify(readConsumer, times(1)).consumeReads(channelBuffer); + + when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0); + + assertEquals(messageLength, context.read()); + + assertEquals(0, channelBuffer.getIndex()); + assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity()); + verify(readConsumer, times(2)).consumeReads(channelBuffer); + } + + public void testReadThrowsIOException() throws IOException { + IOException ioException = new IOException(); + when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException); + + IOException ex = expectThrows(IOException.class, () -> context.read()); + assertSame(ioException, ex); + } + + public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException { + when(channel.read(any(ByteBuffer[].class))).thenThrow(new IOException()); + + assertFalse(context.selectorShouldClose()); + expectThrows(IOException.class, () -> context.read()); + assertTrue(context.selectorShouldClose()); + } + + public void testReadLessThanZeroMeansReadyForClose() throws IOException { + when(channel.read(any(ByteBuffer[].class))).thenReturn(-1); + + assertEquals(0, context.read()); + + assertTrue(context.selectorShouldClose()); + } + + public void testCloseClosesChannelBuffer() throws IOException { + Runnable closer = mock(Runnable.class); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + buffer.ensureCapacity(1); + BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer); + context.closeFromSelector(); + verify(closer).run(); + } + + public void testWriteFailsIfClosing() { + context.closeChannel(); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + context.sendMessage(buffers, listener); + + verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); + } + + public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { + ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); + + when(selector.isOnCurrentThread()).thenReturn(false); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + context.sendMessage(buffers, listener); + + verify(selector).queueWrite(writeOpCaptor.capture()); + BytesWriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(listener, writeOp.getListener()); + assertSame(channel, writeOp.getChannel()); + assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); + } + + public void testSendMessageFromSameThreadIsQueuedInChannel() { + ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + context.sendMessage(buffers, listener); + + verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); + BytesWriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(listener, writeOp.getListener()); + assertSame(channel, writeOp.getChannel()); + assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); + } + + public void testWriteIsQueuedInChannel() { + assertFalse(context.hasQueuedWriteOps()); + + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + + assertTrue(context.hasQueuedWriteOps()); + } + + public void testWriteOpsClearedOnClose() throws Exception { + assertFalse(context.hasQueuedWriteOps()); + + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + + assertTrue(context.hasQueuedWriteOps()); + + context.closeFromSelector(); + + verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + + assertFalse(context.hasQueuedWriteOps()); + } + + public void testQueuedWriteIsFlushedInFlushCall() throws Exception { + assertFalse(context.hasQueuedWriteOps()); + + ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; + BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); + context.queueWriteOperation(writeOperation); + + assertTrue(context.hasQueuedWriteOps()); + + when(writeOperation.getBuffersToWrite()).thenReturn(buffers); + when(writeOperation.isFullyFlushed()).thenReturn(true); + when(writeOperation.getListener()).thenReturn(listener); + context.flushChannel(); + + verify(channel).write(buffers); + verify(selector).executeListener(listener, null); + assertFalse(context.hasQueuedWriteOps()); + } + + public void testPartialFlush() throws IOException { + assertFalse(context.hasQueuedWriteOps()); + + BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); + context.queueWriteOperation(writeOperation); + + assertTrue(context.hasQueuedWriteOps()); + + when(writeOperation.isFullyFlushed()).thenReturn(false); + context.flushChannel(); + + verify(listener, times(0)).accept(null, null); + assertTrue(context.hasQueuedWriteOps()); + } + + @SuppressWarnings("unchecked") + public void testMultipleWritesPartialFlushes() throws IOException { + assertFalse(context.hasQueuedWriteOps()); + + BiConsumer listener2 = mock(BiConsumer.class); + BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class); + BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class); + when(writeOperation1.getListener()).thenReturn(listener); + when(writeOperation2.getListener()).thenReturn(listener2); + context.queueWriteOperation(writeOperation1); + context.queueWriteOperation(writeOperation2); + + assertTrue(context.hasQueuedWriteOps()); + + when(writeOperation1.isFullyFlushed()).thenReturn(true); + when(writeOperation2.isFullyFlushed()).thenReturn(false); + context.flushChannel(); + + verify(selector).executeListener(listener, null); + verify(listener2, times(0)).accept(null, null); + assertTrue(context.hasQueuedWriteOps()); + + when(writeOperation2.isFullyFlushed()).thenReturn(true); + + context.flushChannel(); + + verify(selector).executeListener(listener2, null); + assertFalse(context.hasQueuedWriteOps()); + } + + public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { + assertFalse(context.hasQueuedWriteOps()); + + ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; + BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); + context.queueWriteOperation(writeOperation); + + assertTrue(context.hasQueuedWriteOps()); + + IOException exception = new IOException(); + when(writeOperation.getBuffersToWrite()).thenReturn(buffers); + when(channel.write(buffers)).thenThrow(exception); + when(writeOperation.getListener()).thenReturn(listener); + expectThrows(IOException.class, () -> context.flushChannel()); + + verify(selector).executeFailedListener(listener, exception); + assertFalse(context.hasQueuedWriteOps()); + } + + public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException { + ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; + BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); + context.queueWriteOperation(writeOperation); + + IOException exception = new IOException(); + when(writeOperation.getBuffersToWrite()).thenReturn(buffers); + when(channel.write(buffers)).thenThrow(exception); + + assertFalse(context.selectorShouldClose()); + expectThrows(IOException.class, () -> context.flushChannel()); + assertTrue(context.selectorShouldClose()); + } + + public void initiateCloseSchedulesCloseWithSelector() { + context.closeChannel(); + verify(selector).queueChannelClose(channel); + } + + private static byte[] createMessage(int length) { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; ++i) { + bytes[i] = randomByte(); + } + return bytes; + } +} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesReadContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesReadContextTests.java deleted file mode 100644 index 69f187378ac..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesReadContextTests.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.test.ESTestCase; -import org.junit.Before; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.function.Supplier; - -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class BytesReadContextTests extends ESTestCase { - - private ReadContext.ReadConsumer readConsumer; - private NioSocketChannel channel; - private BytesReadContext readContext; - private InboundChannelBuffer channelBuffer; - private int messageLength; - - @Before - public void init() { - readConsumer = mock(ReadContext.ReadConsumer.class); - - messageLength = randomInt(96) + 20; - channel = mock(NioSocketChannel.class); - Supplier pageSupplier = () -> - new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {}); - channelBuffer = new InboundChannelBuffer(pageSupplier); - readContext = new BytesReadContext(channel, readConsumer, channelBuffer); - } - - public void testSuccessfulRead() throws IOException { - byte[] bytes = createMessage(messageLength); - - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { - ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; - buffers[0].put(bytes); - return bytes.length; - }); - - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0); - - assertEquals(messageLength, readContext.read()); - - assertEquals(0, channelBuffer.getIndex()); - assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(2)).consumeReads(channelBuffer); - } - - public void testMultipleReadsConsumed() throws IOException { - byte[] bytes = createMessage(messageLength * 2); - - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { - ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; - buffers[0].put(bytes); - return bytes.length; - }); - - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0); - - assertEquals(bytes.length, readContext.read()); - - assertEquals(0, channelBuffer.getIndex()); - assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(3)).consumeReads(channelBuffer); - } - - public void testPartialRead() throws IOException { - byte[] bytes = createMessage(messageLength); - - when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> { - ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0]; - buffers[0].put(bytes); - return bytes.length; - }); - - - when(readConsumer.consumeReads(channelBuffer)).thenReturn(0, messageLength); - - assertEquals(messageLength, readContext.read()); - - assertEquals(bytes.length, channelBuffer.getIndex()); - verify(readConsumer, times(1)).consumeReads(channelBuffer); - - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0); - - assertEquals(messageLength, readContext.read()); - - assertEquals(0, channelBuffer.getIndex()); - assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity()); - verify(readConsumer, times(3)).consumeReads(channelBuffer); - } - - public void testReadThrowsIOException() throws IOException { - IOException ioException = new IOException(); - when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException); - - IOException ex = expectThrows(IOException.class, () -> readContext.read()); - assertSame(ioException, ex); - } - - public void closeClosesChannelBuffer() { - InboundChannelBuffer buffer = mock(InboundChannelBuffer.class); - BytesReadContext readContext = new BytesReadContext(channel, readConsumer, buffer); - - readContext.close(); - - verify(buffer).close(); - } - - private static byte[] createMessage(int length) { - byte[] bytes = new byte[length]; - for (int i = 0; i < length; ++i) { - bytes[i] = randomByte(); - } - return bytes; - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteContextTests.java deleted file mode 100644 index 9d5b1c92cb6..00000000000 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteContextTests.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.test.ESTestCase; -import org.junit.Before; -import org.mockito.ArgumentCaptor; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.util.function.BiConsumer; - -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class BytesWriteContextTests extends ESTestCase { - - private SocketSelector selector; - private BiConsumer listener; - private BytesWriteContext writeContext; - private NioSocketChannel channel; - - @Before - @SuppressWarnings("unchecked") - public void setUp() throws Exception { - super.setUp(); - selector = mock(SocketSelector.class); - listener = mock(BiConsumer.class); - channel = mock(NioSocketChannel.class); - writeContext = new BytesWriteContext(channel); - - when(channel.getSelector()).thenReturn(selector); - when(selector.isOnCurrentThread()).thenReturn(true); - } - - public void testWriteFailsIfChannelNotWritable() throws Exception { - when(channel.isWritable()).thenReturn(false); - - ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))}; - writeContext.sendMessage(buffers, listener); - - verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); - } - - public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); - - when(selector.isOnCurrentThread()).thenReturn(false); - when(channel.isWritable()).thenReturn(true); - - ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))}; - writeContext.sendMessage(buffers, listener); - - verify(selector).queueWrite(writeOpCaptor.capture()); - WriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getByteBuffers()[0]); - } - - public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); - - when(channel.isWritable()).thenReturn(true); - - ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))}; - writeContext.sendMessage(buffers, listener); - - verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); - WriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getByteBuffers()[0]); - } - - public void testWriteIsQueuedInChannel() throws Exception { - assertFalse(writeContext.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - writeContext.queueWriteOperations(new WriteOperation(channel, buffer, listener)); - - assertTrue(writeContext.hasQueuedWriteOps()); - } - - public void testWriteOpsCanBeCleared() throws Exception { - assertFalse(writeContext.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - writeContext.queueWriteOperations(new WriteOperation(channel, buffer, listener)); - - assertTrue(writeContext.hasQueuedWriteOps()); - - ClosedChannelException e = new ClosedChannelException(); - writeContext.clearQueuedWriteOps(e); - - verify(selector).executeFailedListener(listener, e); - - assertFalse(writeContext.hasQueuedWriteOps()); - } - - public void testQueuedWriteIsFlushedInFlushCall() throws Exception { - assertFalse(writeContext.hasQueuedWriteOps()); - - WriteOperation writeOperation = mock(WriteOperation.class); - writeContext.queueWriteOperations(writeOperation); - - assertTrue(writeContext.hasQueuedWriteOps()); - - when(writeOperation.isFullyFlushed()).thenReturn(true); - when(writeOperation.getListener()).thenReturn(listener); - writeContext.flushChannel(); - - verify(writeOperation).flush(); - verify(selector).executeListener(listener, null); - assertFalse(writeContext.hasQueuedWriteOps()); - } - - public void testPartialFlush() throws IOException { - assertFalse(writeContext.hasQueuedWriteOps()); - - WriteOperation writeOperation = mock(WriteOperation.class); - writeContext.queueWriteOperations(writeOperation); - - assertTrue(writeContext.hasQueuedWriteOps()); - - when(writeOperation.isFullyFlushed()).thenReturn(false); - writeContext.flushChannel(); - - verify(listener, times(0)).accept(null, null); - assertTrue(writeContext.hasQueuedWriteOps()); - } - - @SuppressWarnings("unchecked") - public void testMultipleWritesPartialFlushes() throws IOException { - assertFalse(writeContext.hasQueuedWriteOps()); - - BiConsumer listener2 = mock(BiConsumer.class); - WriteOperation writeOperation1 = mock(WriteOperation.class); - WriteOperation writeOperation2 = mock(WriteOperation.class); - when(writeOperation1.getListener()).thenReturn(listener); - when(writeOperation2.getListener()).thenReturn(listener2); - writeContext.queueWriteOperations(writeOperation1); - writeContext.queueWriteOperations(writeOperation2); - - assertTrue(writeContext.hasQueuedWriteOps()); - - when(writeOperation1.isFullyFlushed()).thenReturn(true); - when(writeOperation2.isFullyFlushed()).thenReturn(false); - writeContext.flushChannel(); - - verify(selector).executeListener(listener, null); - verify(listener2, times(0)).accept(null, null); - assertTrue(writeContext.hasQueuedWriteOps()); - - when(writeOperation2.isFullyFlushed()).thenReturn(true); - - writeContext.flushChannel(); - - verify(selector).executeListener(listener2, null); - assertFalse(writeContext.hasQueuedWriteOps()); - } - - public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { - assertFalse(writeContext.hasQueuedWriteOps()); - - WriteOperation writeOperation = mock(WriteOperation.class); - writeContext.queueWriteOperations(writeOperation); - - assertTrue(writeContext.hasQueuedWriteOps()); - - IOException exception = new IOException(); - when(writeOperation.flush()).thenThrow(exception); - when(writeOperation.getListener()).thenReturn(listener); - expectThrows(IOException.class, () -> writeContext.flushChannel()); - - verify(selector).executeFailedListener(listener, exception); - assertFalse(writeContext.hasQueuedWriteOps()); - } - - private byte[] generateBytes(int n) { - n += 10; - byte[] bytes = new byte[n]; - for (int i = 0; i < n; ++i) { - bytes[i] = randomByte(); - } - return bytes; - } -} diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java index c1183af4e5b..e3f42139fd8 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java @@ -139,7 +139,7 @@ public class ChannelFactoryTests extends ESTestCase { @Override public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); - nioSocketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + nioSocketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class)); return nioSocketChannel; } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java index 713f01ec283..4f4673140fd 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java @@ -82,7 +82,7 @@ public class NioServerSocketChannelTests extends ESTestCase { PlainActionFuture closeFuture = PlainActionFuture.newFuture(); channel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - channel.close(); + selector.queueChannelClose(channel); closeFuture.actionGet(); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java index 6a32b11f18b..dd0956458fa 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java @@ -35,6 +35,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -66,7 +67,7 @@ public class NioSocketChannelTests extends ESTestCase { CountDownLatch latch = new CountDownLatch(1); NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class)); socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { @Override public void onResponse(Void o) { @@ -86,7 +87,45 @@ public class NioSocketChannelTests extends ESTestCase { PlainActionFuture closeFuture = PlainActionFuture.newFuture(); socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); - socketChannel.close(); + selector.queueChannelClose(socketChannel); + closeFuture.actionGet(); + + assertTrue(closedRawChannel.get()); + assertFalse(socketChannel.isOpen()); + latch.await(); + assertTrue(isClosed.get()); + } + + @SuppressWarnings("unchecked") + public void testCloseContextExceptionDoesNotStopClose() throws Exception { + AtomicBoolean isClosed = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + IOException ioException = new IOException(); + NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector); + ChannelContext context = mock(ChannelContext.class); + doThrow(ioException).when(context).closeFromSelector(); + socketChannel.setContexts(context, mock(BiConsumer.class)); + socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener() { + @Override + public void onResponse(Void o) { + isClosed.set(true); + latch.countDown(); + } + @Override + public void onFailure(Exception e) { + isClosed.set(true); + latch.countDown(); + } + })); + + assertTrue(socketChannel.isOpen()); + assertFalse(closedRawChannel.get()); + assertFalse(isClosed.get()); + + PlainActionFuture closeFuture = PlainActionFuture.newFuture(); + socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); + selector.queueChannelClose(socketChannel); closeFuture.actionGet(); assertTrue(closedRawChannel.get()); @@ -100,7 +139,7 @@ public class NioSocketChannelTests extends ESTestCase { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenReturn(true); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); @@ -117,7 +156,7 @@ public class NioSocketChannelTests extends ESTestCase { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenThrow(new ConnectException()); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java index 2898cf18d5b..e0f833c9051 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java @@ -28,8 +28,10 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; +import java.util.function.Supplier; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,7 +41,6 @@ public class SocketEventHandlerTests extends ESTestCase { private SocketEventHandler handler; private NioSocketChannel channel; - private ReadContext readContext; private SocketChannel rawChannel; @Before @@ -50,21 +51,37 @@ public class SocketEventHandlerTests extends ESTestCase { handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); channel = new DoNotRegisterChannel(rawChannel, socketSelector); - readContext = mock(ReadContext.class); when(rawChannel.finishConnect()).thenReturn(true); - channel.setContexts(readContext, new BytesWriteContext(channel), exceptionHandler); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {}); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + channel.setContexts(new BytesChannelContext(channel, mock(ChannelContext.ReadConsumer.class), buffer), exceptionHandler); channel.register(); channel.finishConnect(); when(socketSelector.isOnCurrentThread()).thenReturn(true); } + public void testRegisterCallsContext() throws IOException { + NioSocketChannel channel = mock(NioSocketChannel.class); + ChannelContext channelContext = mock(ChannelContext.class); + when(channel.getContext()).thenReturn(channelContext); + when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + handler.handleRegistration(channel); + verify(channelContext).channelRegistered(); + } + public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException { handler.handleRegistration(channel); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps()); } + public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { + channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class)); + handler.handleRegistration(channel); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + } + public void testRegistrationExceptionCallsExceptionHandler() throws IOException { CancelledKeyException exception = new CancelledKeyException(); handler.registrationException(channel, exception); @@ -83,68 +100,76 @@ public class SocketEventHandlerTests extends ESTestCase { verify(exceptionHandler).accept(channel, exception); } - public void testHandleReadDelegatesToReadContext() throws IOException { - when(readContext.read()).thenReturn(1); + public void testHandleReadDelegatesToContext() throws IOException { + NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); + ChannelContext context = mock(ChannelContext.class); + channel.setContexts(context, exceptionHandler); + when(context.read()).thenReturn(1); handler.handleRead(channel); - - verify(readContext).read(); + verify(context).read(); } - public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException { - NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class); - when(nioSocketChannel.getReadContext()).thenReturn(readContext); - when(readContext.read()).thenReturn(-1); - - handler.handleRead(nioSocketChannel); - - verify(nioSocketChannel).closeFromSelector(); - } - - public void testReadExceptionCallsExceptionHandler() throws IOException { + public void testReadExceptionCallsExceptionHandler() { IOException exception = new IOException(); handler.readException(channel, exception); verify(exceptionHandler).accept(channel, exception); } - @SuppressWarnings("unchecked") - public void testHandleWriteWithCompleteFlushRemovesOP_WRITEInterest() throws IOException { - SelectionKey selectionKey = channel.getSelectionKey(); - setWriteAndRead(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); - - ByteBuffer[] buffers = {ByteBuffer.allocate(1)}; - channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class))); - - when(rawChannel.write(buffers[0])).thenReturn(1); - handler.handleWrite(channel); - - assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + public void testWriteExceptionCallsExceptionHandler() { + IOException exception = new IOException(); + handler.writeException(channel, exception); + verify(exceptionHandler).accept(channel, exception); } - @SuppressWarnings("unchecked") - public void testHandleWriteWithInCompleteFlushLeavesOP_WRITEInterest() throws IOException { - SelectionKey selectionKey = channel.getSelectionKey(); - setWriteAndRead(channel); - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); + public void testPostHandlingCallWillCloseTheChannelIfReady() throws IOException { + NioSocketChannel channel = mock(NioSocketChannel.class); + ChannelContext context = mock(ChannelContext.class); + when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); - ByteBuffer[] buffers = {ByteBuffer.allocate(1)}; - channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class))); + when(channel.getContext()).thenReturn(context); + when(context.selectorShouldClose()).thenReturn(true); + handler.postHandling(channel); - when(rawChannel.write(buffers[0])).thenReturn(0); - handler.handleWrite(channel); - - assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); + verify(channel).closeFromSelector(); } - public void testHandleWriteWithNoOpsRemovesOP_WRITEInterest() throws IOException { - SelectionKey selectionKey = channel.getSelectionKey(); - setWriteAndRead(channel); + public void testPostHandlingCallWillNotCloseTheChannelIfNotReady() throws IOException { + NioSocketChannel channel = mock(NioSocketChannel.class); + ChannelContext context = mock(ChannelContext.class); + when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0)); + + when(channel.getContext()).thenReturn(context); + when(context.selectorShouldClose()).thenReturn(false); + handler.postHandling(channel); + + verify(channel, times(0)).closeFromSelector(); + } + + public void testPostHandlingWillAddWriteIfNecessary() throws IOException { + NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); + channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ)); + ChannelContext context = mock(ChannelContext.class); + channel.setContexts(context, null); + + when(context.hasQueuedWriteOps()).thenReturn(true); + + assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + handler.postHandling(channel); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + } - handler.handleWrite(channel); + public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException { + NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class)); + channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE)); + ChannelContext context = mock(ChannelContext.class); + channel.setContexts(context, null); - assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + when(context.hasQueuedWriteOps()).thenReturn(false); + + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + handler.postHandling(channel); + assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); } private void setWriteAndRead(NioChannel channel) { @@ -152,10 +177,4 @@ public class SocketEventHandlerTests extends ESTestCase { SelectionKeyUtils.removeConnectInterested(channel); SelectionKeyUtils.setWriteInterested(channel); } - - public void testWriteExceptionCallsExceptionHandler() throws IOException { - IOException exception = new IOException(); - handler.writeException(channel, exception); - verify(exceptionHandler).accept(channel, exception); - } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java index e50da352623..9197fe38dbc 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java @@ -49,7 +49,7 @@ public class SocketSelectorTests extends ESTestCase { private SocketEventHandler eventHandler; private NioSocketChannel channel; private TestSelectionKey selectionKey; - private WriteContext writeContext; + private ChannelContext channelContext; private BiConsumer listener; private ByteBuffer[] buffers = {ByteBuffer.allocate(1)}; private Selector rawSelector; @@ -60,7 +60,7 @@ public class SocketSelectorTests extends ESTestCase { super.setUp(); eventHandler = mock(SocketEventHandler.class); channel = mock(NioSocketChannel.class); - writeContext = mock(WriteContext.class); + channelContext = mock(ChannelContext.class); listener = mock(BiConsumer.class); selectionKey = new TestSelectionKey(0); selectionKey.attach(channel); @@ -71,7 +71,7 @@ public class SocketSelectorTests extends ESTestCase { when(channel.isOpen()).thenReturn(true); when(channel.getSelectionKey()).thenReturn(selectionKey); - when(channel.getWriteContext()).thenReturn(writeContext); + when(channel.getContext()).thenReturn(channelContext); when(channel.isConnectComplete()).thenReturn(true); when(channel.getSelector()).thenReturn(socketSelector); } @@ -129,75 +129,71 @@ public class SocketSelectorTests extends ESTestCase { public void testQueueWriteWhenNotRunning() throws Exception { socketSelector.close(); - socketSelector.queueWrite(new WriteOperation(channel, buffers, listener)); + socketSelector.queueWrite(new BytesWriteOperation(channel, buffers, listener)); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); } - public void testQueueWriteChannelIsNoLongerWritable() throws Exception { - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); + public void testQueueWriteChannelIsClosed() throws Exception { + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); socketSelector.queueWrite(writeOperation); - when(channel.isWritable()).thenReturn(false); + when(channel.isOpen()).thenReturn(false); socketSelector.preSelect(); - verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(channelContext, times(0)).queueWriteOperation(writeOperation); verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); } public void testQueueWriteSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); socketSelector.queueWrite(writeOperation); - when(channel.isWritable()).thenReturn(true); when(channel.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.preSelect(); - verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(channelContext, times(0)).queueWriteOperation(writeOperation); verify(listener).accept(null, cancelledKeyException); } public void testQueueWriteSuccessful() throws Exception { - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); socketSelector.queueWrite(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); - when(channel.isWritable()).thenReturn(true); socketSelector.preSelect(); - verify(writeContext).queueWriteOperations(writeOperation); + verify(channelContext).queueWriteOperation(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); } public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); - when(channel.isWritable()).thenReturn(true); socketSelector.queueWriteInChannelBuffer(writeOperation); - verify(writeContext).queueWriteOperations(writeOperation); + verify(channelContext).queueWriteOperation(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); } public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); - when(channel.isWritable()).thenReturn(true); when(channel.getSelectionKey()).thenReturn(selectionKey); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); socketSelector.queueWriteInChannelBuffer(writeOperation); - verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(channelContext, times(0)).queueWriteOperation(writeOperation); verify(listener).accept(null, cancelledKeyException); } @@ -285,6 +281,16 @@ public class SocketSelectorTests extends ESTestCase { verify(eventHandler).readException(channel, ioException); } + public void testWillCallPostHandleAfterChannelHandling() throws Exception { + selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ); + + socketSelector.processKey(selectionKey); + + verify(eventHandler).handleWrite(channel); + verify(eventHandler).handleRead(channel); + verify(eventHandler).postHandling(channel); + } + public void testCleanup() throws Exception { NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); @@ -292,7 +298,7 @@ public class SocketSelectorTests extends ESTestCase { socketSelector.preSelect(); - socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), buffers, listener)); + socketSelector.queueWrite(new BytesWriteOperation(mock(NioSocketChannel.class), buffers, listener)); socketSelector.scheduleForRegistration(unRegisteredChannel); TestSelectionKey testSelectionKey = new TestSelectionKey(0); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java index da74269b825..59fb9cde438 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java @@ -45,71 +45,58 @@ public class WriteOperationTests extends ESTestCase { } - public void testFlush() throws IOException { + public void testFullyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - WriteOperation writeOp = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); - - when(channel.write(any(ByteBuffer[].class))).thenReturn(10); - - writeOp.flush(); + writeOp.incrementIndex(10); assertTrue(writeOp.isFullyFlushed()); } - public void testPartialFlush() throws IOException { + public void testPartiallyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - WriteOperation writeOp = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); - when(channel.write(any(ByteBuffer[].class))).thenReturn(5); - - writeOp.flush(); + writeOp.incrementIndex(5); assertFalse(writeOp.isFullyFlushed()); } public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)}; - WriteOperation writeOp = new WriteOperation(channel, buffers, listener); + BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener); ArgumentCaptor buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class); - when(channel.write(buffersCaptor.capture())).thenReturn(5) - .thenReturn(5) - .thenReturn(2) - .thenReturn(15) - .thenReturn(1); - - writeOp.flush(); + writeOp.incrementIndex(5); assertFalse(writeOp.isFullyFlushed()); - writeOp.flush(); - assertFalse(writeOp.isFullyFlushed()); - writeOp.flush(); - assertFalse(writeOp.isFullyFlushed()); - writeOp.flush(); - assertFalse(writeOp.isFullyFlushed()); - writeOp.flush(); - assertTrue(writeOp.isFullyFlushed()); - - List values = buffersCaptor.getAllValues(); - ByteBuffer[] byteBuffers = values.get(0); - assertEquals(3, byteBuffers.length); - assertEquals(10, byteBuffers[0].remaining()); - - byteBuffers = values.get(1); + ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite(); assertEquals(3, byteBuffers.length); assertEquals(5, byteBuffers[0].remaining()); - byteBuffers = values.get(2); + writeOp.incrementIndex(5); + assertFalse(writeOp.isFullyFlushed()); + byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(15, byteBuffers[0].remaining()); - byteBuffers = values.get(3); + writeOp.incrementIndex(2); + assertFalse(writeOp.isFullyFlushed()); + byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(13, byteBuffers[0].remaining()); - byteBuffers = values.get(4); + writeOp.incrementIndex(15); + assertFalse(writeOp.isFullyFlushed()); + byteBuffers = writeOp.getBuffersToWrite(); assertEquals(1, byteBuffers.length); assertEquals(1, byteBuffers[0].remaining()); + + writeOp.incrementIndex(1); + assertTrue(writeOp.isFullyFlushed()); + byteBuffers = writeOp.getBuffersToWrite(); + assertEquals(1, byteBuffers.length); + assertEquals(0, byteBuffers[0].remaining()); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 9917bf79f59..d25d3c5974a 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -33,13 +33,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptorEventHandler; -import org.elasticsearch.nio.BytesReadContext; -import org.elasticsearch.nio.BytesWriteContext; +import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.ChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.ReadContext; import org.elasticsearch.nio.SocketEventHandler; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; @@ -72,12 +71,12 @@ public class NioTransport extends TcpTransport { public static final Setting NIO_ACCEPTOR_COUNT = intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); - private final PageCacheRecycler pageCacheRecycler; + protected final PageCacheRecycler pageCacheRecycler; private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private volatile NioGroup nioGroup; private volatile TcpChannelFactory clientChannelFactory; - NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + protected NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); @@ -111,13 +110,13 @@ public class NioTransport extends TcpTransport { NioTransport.NIO_WORKER_COUNT.get(settings), SocketEventHandler::new); ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); - clientChannelFactory = new TcpChannelFactory(clientProfileSettings); + clientChannelFactory = channelFactory(clientProfileSettings, true); if (useNetworkServer) { // loop through all profiles and start them up, special handling for default one for (ProfileSettings profileSettings : profileSettings) { String profileName = profileSettings.profileName; - TcpChannelFactory factory = new TcpChannelFactory(profileSettings); + TcpChannelFactory factory = channelFactory(profileSettings, false); profileToChannelFactory.putIfAbsent(profileName, factory); bindServer(profileSettings); } @@ -144,19 +143,30 @@ public class NioTransport extends TcpTransport { profileToChannelFactory.clear(); } - private void exceptionCaught(NioSocketChannel channel, Exception exception) { + protected void exceptionCaught(NioSocketChannel channel, Exception exception) { onException((TcpChannel) channel, exception); } - private void acceptChannel(NioSocketChannel channel) { + protected void acceptChannel(NioSocketChannel channel) { serverAcceptedChannel((TcpNioSocketChannel) channel); } - private class TcpChannelFactory extends ChannelFactory { + protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) { + return new TcpChannelFactoryImpl(settings); + } + + protected abstract class TcpChannelFactory extends ChannelFactory { + + protected TcpChannelFactory(RawChannelFactory rawChannelFactory) { + super(rawChannelFactory); + } + } + + private class TcpChannelFactoryImpl extends TcpChannelFactory { private final String profileName; - TcpChannelFactory(TcpTransport.ProfileSettings profileSettings) { + private TcpChannelFactoryImpl(ProfileSettings profileSettings) { super(new RawChannelFactory(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress, @@ -172,10 +182,10 @@ public class NioTransport extends TcpTransport { Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - ReadContext.ReadConsumer nioReadConsumer = channelBuffer -> + ChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); - nioChannel.setContexts(readContext, new BytesWriteContext(nioChannel), NioTransport.this::exceptionCaught); + BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); + nioChannel.setContexts(context, NioTransport.this::exceptionCaught); return nioChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java index 7f657c76348..f0d01bf5a7d 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java @@ -38,7 +38,7 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements private final String profile; - TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, + public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector) throws IOException { super(socketChannel, channelFactory, selector); @@ -60,6 +60,11 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements return null; } + @Override + public void close() { + getSelector().queueChannelClose(this); + } + @Override public String getProfile() { return profile; diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java index 5633899a04b..c2064e53ca6 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java @@ -33,13 +33,13 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel private final String profile; - TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { + public TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { super(socketChannel, selector); this.profile = profile; } public void sendMessage(BytesReference reference, ActionListener listener) { - getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); + getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); } @Override @@ -59,6 +59,11 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel addCloseListener(ActionListener.toBiConsumer(listener)); } + @Override + public void close() { + getContext().closeChannel(); + } + @Override public String toString() { return "TcpNioSocketChannel{" + diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index a8876453b5b..c5ec4c6bfb7 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -31,14 +31,13 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptorEventHandler; -import org.elasticsearch.nio.BytesReadContext; -import org.elasticsearch.nio.BytesWriteContext; +import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.ChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.ReadContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpChannel; @@ -162,11 +161,10 @@ public class MockNioTransport extends TcpTransport { Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - ReadContext.ReadConsumer nioReadConsumer = channelBuffer -> + ChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); - BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); - BytesWriteContext writeContext = new BytesWriteContext(nioChannel); - nioChannel.setContexts(readContext, writeContext, MockNioTransport.this::exceptionCaught); + BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); + nioChannel.setContexts(context, MockNioTransport.this::exceptionCaught); return nioChannel; } @@ -188,6 +186,11 @@ public class MockNioTransport extends TcpTransport { this.profile = profile; } + @Override + public void close() { + getSelector().queueChannelClose(this); + } + @Override public String getProfile() { return profile; @@ -224,6 +227,11 @@ public class MockNioTransport extends TcpTransport { this.profile = profile; } + @Override + public void close() { + getContext().closeChannel(); + } + @Override public String getProfile() { return profile; @@ -243,7 +251,7 @@ public class MockNioTransport extends TcpTransport { @Override public void sendMessage(BytesReference reference, ActionListener listener) { - getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); + getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); } } }