Unify nio read / write channel contexts (#28160)

This commit is related to #27260. Right now we have separate read and
write contexts for implementing specific protocol logic. However, some
protocols require a closer relationship between read and write
operations than is allowed by our current model. An example is HTTP
which might require a write if some problem with request parsing was
encountered.

Additionally, some protocols require close messages to be sent when a
channel is shutdown. This is also problematic in our current model,
where we assume that channels should simply be queued for close and
forgotten.

This commit transitions to a single ChannelContext which implements
all read, write, and close logic for protocols. It is the job of the
context to tell the selector when to close the channel. A channel can
still be manually queued for close with a selector. This is how server
channels are closed for now. And this route allows timeout mechanisms on
normal channel closes to be implemented.
This commit is contained in:
Tim Brooks 2018-01-17 09:44:21 -07:00 committed by GitHub
parent 1f66672d6f
commit 4ea9ddb7d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1028 additions and 893 deletions

View File

@ -26,7 +26,6 @@ import java.nio.channels.NetworkChannel;
import java.nio.channels.SelectableChannel; import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
/** /**
@ -48,9 +47,6 @@ import java.util.function.BiConsumer;
public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkChannel> implements NioChannel { public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkChannel> implements NioChannel {
final S socketChannel; final S socketChannel;
// This indicates if the channel has been scheduled to be closed. Read the closeFuture to determine if
// the channel close process has completed.
final AtomicBoolean isClosing = new AtomicBoolean(false);
private final InetSocketAddress localAddress; private final InetSocketAddress localAddress;
private final CompletableFuture<Void> closeContext = new CompletableFuture<>(); private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
@ -73,21 +69,6 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
return localAddress; return localAddress;
} }
/**
* Schedules a channel to be closed by the selector event loop with which it is registered.
* <p>
* If the channel is open and the state can be transitioned to closed, the close operation will
* be scheduled with the event loop.
* <p>
* If the channel is already set to closed, it is assumed that it is already scheduled to be closed.
*/
@Override
public void close() {
if (isClosing.compareAndSet(false, true)) {
selector.queueChannelClose(this);
}
}
/** /**
* Closes the channel synchronously. This method should only be called from the selector thread. * Closes the channel synchronously. This method should only be called from the selector thread.
* <p> * <p>
@ -95,8 +76,7 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
*/ */
@Override @Override
public void closeFromSelector() throws IOException { public void closeFromSelector() throws IOException {
assert selector.isOnCurrentThread() : "Should only call from selector thread"; selector.assertOnSelectorThread();
isClosing.set(true);
if (closeContext.isDone() == false) { if (closeContext.isDone() == false) {
try { try {
closeRawChannel(); closeRawChannel();

View File

@ -0,0 +1,169 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
public class BytesChannelContext implements ChannelContext {
private final NioSocketChannel channel;
private final ReadConsumer readConsumer;
private final InboundChannelBuffer channelBuffer;
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private boolean peerClosed = false;
private boolean ioException = false;
public BytesChannelContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) {
this.channel = channel;
this.readConsumer = readConsumer;
this.channelBuffer = channelBuffer;
}
@Override
public void channelRegistered() throws IOException {}
@Override
public int read() throws IOException {
if (channelBuffer.getRemaining() == 0) {
// Requiring one additional byte will ensure that a new page is allocated.
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
}
int bytesRead;
try {
bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
} catch (IOException ex) {
ioException = true;
throw ex;
}
if (bytesRead == -1) {
peerClosed = true;
return 0;
}
channelBuffer.incrementIndex(bytesRead);
int bytesConsumed = Integer.MAX_VALUE;
while (bytesConsumed > 0 && channelBuffer.getIndex() > 0) {
bytesConsumed = readConsumer.consumeReads(channelBuffer);
channelBuffer.release(bytesConsumed);
}
return bytesRead;
}
@Override
public void sendMessage(ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
if (isClosing.get()) {
listener.accept(null, new ClosedChannelException());
return;
}
BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
return;
}
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation);
}
@Override
public void queueWriteOperation(WriteOperation writeOperation) {
channel.getSelector().assertOnSelectorThread();
queued.add((BytesWriteOperation) writeOperation);
}
@Override
public void flushChannel() throws IOException {
channel.getSelector().assertOnSelectorThread();
int ops = queued.size();
if (ops == 1) {
singleFlush(queued.pop());
} else if (ops > 1) {
multiFlush();
}
}
@Override
public boolean hasQueuedWriteOps() {
channel.getSelector().assertOnSelectorThread();
return queued.isEmpty() == false;
}
@Override
public void closeChannel() {
if (isClosing.compareAndSet(false, true)) {
channel.getSelector().queueChannelClose(channel);
}
}
@Override
public boolean selectorShouldClose() {
return peerClosed || ioException || isClosing.get();
}
@Override
public void closeFromSelector() {
channel.getSelector().assertOnSelectorThread();
// Set to true in order to reject new writes before queuing with selector
isClosing.set(true);
channelBuffer.close();
for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
}
queued.clear();
}
private void singleFlush(BytesWriteOperation headOp) throws IOException {
try {
int written = channel.write(headOp.getBuffersToWrite());
headOp.incrementIndex(written);
} catch (IOException e) {
channel.getSelector().executeFailedListener(headOp.getListener(), e);
ioException = true;
throw e;
}
if (headOp.isFullyFlushed()) {
channel.getSelector().executeListener(headOp.getListener(), null);
} else {
queued.push(headOp);
}
}
private void multiFlush() throws IOException {
boolean lastOpCompleted = true;
while (lastOpCompleted && queued.isEmpty() == false) {
BytesWriteOperation op = queued.pop();
singleFlush(op);
lastOpCompleted = op.isFullyFlushed();
}
}
}

View File

@ -1,64 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
public class BytesReadContext implements ReadContext {
private final NioSocketChannel channel;
private final ReadConsumer readConsumer;
private final InboundChannelBuffer channelBuffer;
public BytesReadContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) {
this.channel = channel;
this.channelBuffer = channelBuffer;
this.readConsumer = readConsumer;
}
@Override
public int read() throws IOException {
if (channelBuffer.getRemaining() == 0) {
// Requiring one additional byte will ensure that a new page is allocated.
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
}
int bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
if (bytesRead == -1) {
return bytesRead;
}
channelBuffer.incrementIndex(bytesRead);
int bytesConsumed = Integer.MAX_VALUE;
while (bytesConsumed > 0) {
bytesConsumed = readConsumer.consumeReads(channelBuffer);
channelBuffer.release(bytesConsumed);
}
return bytesRead;
}
@Override
public void close() {
channelBuffer.close();
}
}

View File

@ -1,111 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.function.BiConsumer;
public class BytesWriteContext implements WriteContext {
private final NioSocketChannel channel;
private final LinkedList<WriteOperation> queued = new LinkedList<>();
public BytesWriteContext(NioSocketChannel channel) {
this.channel = channel;
}
@Override
public void sendMessage(Object message, BiConsumer<Void, Throwable> listener) {
ByteBuffer[] buffers = (ByteBuffer[]) message;
if (channel.isWritable() == false) {
listener.accept(null, new ClosedChannelException());
return;
}
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
return;
}
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation);
}
@Override
public void queueWriteOperations(WriteOperation writeOperation) {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to queue writes";
queued.add(writeOperation);
}
@Override
public void flushChannel() throws IOException {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to flush writes";
int ops = queued.size();
if (ops == 1) {
singleFlush(queued.pop());
} else if (ops > 1) {
multiFlush();
}
}
@Override
public boolean hasQueuedWriteOps() {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to access queued writes";
return queued.isEmpty() == false;
}
@Override
public void clearQueuedWriteOps(Exception e) {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to clear queued writes";
for (WriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), e);
}
queued.clear();
}
private void singleFlush(WriteOperation headOp) throws IOException {
try {
headOp.flush();
} catch (IOException e) {
channel.getSelector().executeFailedListener(headOp.getListener(), e);
throw e;
}
if (headOp.isFullyFlushed()) {
channel.getSelector().executeListener(headOp.getListener(), null);
} else {
queued.push(headOp);
}
}
private void multiFlush() throws IOException {
boolean lastOpCompleted = true;
while (lastOpCompleted && queued.isEmpty() == false) {
WriteOperation op = queued.pop();
singleFlush(op);
lastOpCompleted = op.isFullyFlushed();
}
}
}

View File

@ -0,0 +1,88 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.function.BiConsumer;
public class BytesWriteOperation implements WriteOperation {
private final NioSocketChannel channel;
private final BiConsumer<Void, Throwable> listener;
private final ByteBuffer[] buffers;
private final int[] offsets;
private final int length;
private int internalIndex;
public BytesWriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
this.channel = channel;
this.listener = listener;
this.buffers = buffers;
this.offsets = new int[buffers.length];
int offset = 0;
for (int i = 0; i < buffers.length; i++) {
ByteBuffer buffer = buffers[i];
offsets[i] = offset;
offset += buffer.remaining();
}
length = offset;
}
@Override
public BiConsumer<Void, Throwable> getListener() {
return listener;
}
@Override
public NioSocketChannel getChannel() {
return channel;
}
public boolean isFullyFlushed() {
assert length >= internalIndex : "Should never have an index that is greater than the length [length=" + length + ", index="
+ internalIndex + "]";
return internalIndex == length;
}
public void incrementIndex(int delta) {
internalIndex += delta;
assert length >= internalIndex : "Should never increment index past length [length=" + length + ", post-increment index="
+ internalIndex + ", delta=" + delta + "]";
}
public ByteBuffer[] getBuffersToWrite() {
final int index = Arrays.binarySearch(offsets, internalIndex);
int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
firstBuffer.position(internalIndex - offsets[offsetIndex]);
postIndexBuffers[0] = firstBuffer;
int j = 1;
for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
postIndexBuffers[j++] = buffers[i].duplicate();
}
return postIndexBuffers;
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.function.BiConsumer;
/**
* This context should implement the specific logic for a channel. When a channel receives a notification
* that it is ready to perform certain operations (read, write, etc) the {@link ChannelContext} will be
* called. This context will need to implement all protocol related logic. Additionally, if any special
* close behavior is required, it should be implemented in this context.
*
* The only methods of the context that should ever be called from a non-selector thread are
* {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}.
*/
public interface ChannelContext {
void channelRegistered() throws IOException;
int read() throws IOException;
void sendMessage(ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener);
void queueWriteOperation(WriteOperation writeOperation);
void flushChannel() throws IOException;
boolean hasQueuedWriteOps();
/**
* Schedules a channel to be closed by the selector event loop with which it is registered.
* <p>
* If the channel is open and the state can be transitioned to closed, the close operation will
* be scheduled with the event loop.
* <p>
* If the channel is already set to closed, it is assumed that it is already scheduled to be closed.
* <p>
* Depending on the underlying protocol of the channel, a close operation might simply close the socket
* channel or may involve reading and writing messages.
*/
void closeChannel();
/**
* This method indicates if a selector should close this channel.
*
* @return a boolean indicating if the selector should close
*/
boolean selectorShouldClose();
/**
* This method cleans up any context resources that need to be released when a channel is closed. It
* should only be called by the selector thread.
*
* @throws IOException during channel / context close
*/
void closeFromSelector() throws IOException;
@FunctionalInterface
interface ReadConsumer {
int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
}
}

View File

@ -88,8 +88,7 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
private Socket internalCreateChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException { private Socket internalCreateChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException {
try { try {
Socket channel = createChannel(selector, rawChannel); Socket channel = createChannel(selector, rawChannel);
assert channel.getReadContext() != null : "read context should have been set on channel"; assert channel.getContext() != null : "channel context should have been set on channel";
assert channel.getWriteContext() != null : "write context should have been set on channel";
assert channel.getExceptionContext() != null : "exception handler should have been set on channel"; assert channel.getExceptionContext() != null : "exception handler should have been set on channel";
return channel; return channel;
} catch (Exception e) { } catch (Exception e) {

View File

@ -163,6 +163,11 @@ public abstract class ESSelector implements Closeable {
return Thread.currentThread() == thread; return Thread.currentThread() == thread;
} }
public void assertOnSelectorThread() {
assert isOnCurrentThread() : "Must be on selector thread to perform this operation. Currently on thread ["
+ Thread.currentThread().getName() + "].";
}
void wakeup() { void wakeup() {
// TODO: Do we need the wakeup optimizations that some other libraries use? // TODO: Do we need the wakeup optimizations that some other libraries use?
selector.wakeup(); selector.wakeup();

View File

@ -32,8 +32,6 @@ public interface NioChannel {
InetSocketAddress getLocalAddress(); InetSocketAddress getLocalAddress();
void close();
void closeFromSelector() throws IOException; void closeFromSelector() throws IOException;
void register() throws ClosedChannelException; void register() throws ClosedChannelException;

View File

@ -19,11 +19,13 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
@ -34,8 +36,7 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
private final CompletableFuture<Void> connectContext = new CompletableFuture<>(); private final CompletableFuture<Void> connectContext = new CompletableFuture<>();
private final SocketSelector socketSelector; private final SocketSelector socketSelector;
private final AtomicBoolean contextsSet = new AtomicBoolean(false); private final AtomicBoolean contextsSet = new AtomicBoolean(false);
private WriteContext writeContext; private ChannelContext context;
private ReadContext readContext;
private BiConsumer<NioSocketChannel, Exception> exceptionContext; private BiConsumer<NioSocketChannel, Exception> exceptionContext;
private Exception connectException; private Exception connectException;
@ -47,14 +48,21 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
@Override @Override
public void closeFromSelector() throws IOException { public void closeFromSelector() throws IOException {
assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; getSelector().assertOnSelectorThread();
// Even if the channel has already been closed we will clear any pending write operations just in case if (isOpen()) {
if (writeContext.hasQueuedWriteOps()) { ArrayList<IOException> closingExceptions = new ArrayList<>(2);
writeContext.clearQueuedWriteOps(new ClosedChannelException()); try {
}
readContext.close();
super.closeFromSelector(); super.closeFromSelector();
} catch (IOException e) {
closingExceptions.add(e);
}
try {
context.closeFromSelector();
} catch (IOException e) {
closingExceptions.add(e);
}
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
}
} }
@Override @Override
@ -62,6 +70,10 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
return socketSelector; return socketSelector;
} }
public int write(ByteBuffer buffer) throws IOException {
return socketChannel.write(buffer);
}
public int write(ByteBuffer[] buffers) throws IOException { public int write(ByteBuffer[] buffers) throws IOException {
if (buffers.length == 1) { if (buffers.length == 1) {
return socketChannel.write(buffers[0]); return socketChannel.write(buffers[0]);
@ -82,33 +94,17 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
} }
} }
public int read(InboundChannelBuffer buffer) throws IOException { public void setContexts(ChannelContext context, BiConsumer<NioSocketChannel, Exception> exceptionContext) {
int bytesRead = (int) socketChannel.read(buffer.sliceBuffersFrom(buffer.getIndex()));
if (bytesRead == -1) {
return bytesRead;
}
buffer.incrementIndex(bytesRead);
return bytesRead;
}
public void setContexts(ReadContext readContext, WriteContext writeContext, BiConsumer<NioSocketChannel, Exception> exceptionContext) {
if (contextsSet.compareAndSet(false, true)) { if (contextsSet.compareAndSet(false, true)) {
this.readContext = readContext; this.context = context;
this.writeContext = writeContext;
this.exceptionContext = exceptionContext; this.exceptionContext = exceptionContext;
} else { } else {
throw new IllegalStateException("Contexts on this channel were already set. They should only be once."); throw new IllegalStateException("Contexts on this channel were already set. They should only be once.");
} }
} }
public WriteContext getWriteContext() { public ChannelContext getContext() {
return writeContext; return context;
}
public ReadContext getReadContext() {
return readContext;
} }
public BiConsumer<NioSocketChannel, Exception> getExceptionContext() { public BiConsumer<NioSocketChannel, Exception> getExceptionContext() {
@ -123,14 +119,6 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
return isConnectComplete0(); return isConnectComplete0();
} }
public boolean isWritable() {
return isClosing.get() == false;
}
public boolean isReadable() {
return isClosing.get() == false;
}
/** /**
* This method will attempt to complete the connection process for this channel. It should be called for * This method will attempt to complete the connection process for this channel. It should be called for
* new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then

View File

@ -1,35 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
public interface ReadContext extends AutoCloseable {
int read() throws IOException;
@Override
void close();
@FunctionalInterface
interface ReadConsumer {
int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
}
}

View File

@ -26,28 +26,81 @@ public final class SelectionKeyUtils {
private SelectionKeyUtils() {} private SelectionKeyUtils() {}
/**
* Adds an interest in writes for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void setWriteInterested(NioChannel channel) throws CancelledKeyException { public static void setWriteInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey(); SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE);
} }
/**
* Removes an interest in writes for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException { public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey(); SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE); selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE);
} }
/**
* Removes an interest in connects and reads for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException { public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey(); SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ);
} }
/**
* Removes an interest in connects, reads, and writes for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void setConnectReadAndWriteInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ | SelectionKey.OP_WRITE);
}
/**
* Removes an interest in connects for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException { public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey(); SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT); selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT);
} }
public static void setAcceptInterested(NioServerSocketChannel channel) { /**
* Adds an interest in accepts for this channel while maintaining other interests.
*
* @param channel the channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static void setAcceptInterested(NioServerSocketChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey(); SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT); selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT);
} }
/**
* Checks for an interest in writes for this channel.
*
* @param channel the channel
* @return a boolean indicating if we are currently interested in writes for this channel
* @throws CancelledKeyException if the key was already cancelled
*/
public static boolean isWriteInterested(NioSocketChannel channel) throws CancelledKeyException {
return (channel.getSelectionKey().interestOps() & SelectionKey.OP_WRITE) != 0;
}
} }

View File

@ -43,9 +43,15 @@ public class SocketEventHandler extends EventHandler {
* *
* @param channel that was registered * @param channel that was registered
*/ */
protected void handleRegistration(NioSocketChannel channel) { protected void handleRegistration(NioSocketChannel channel) throws IOException {
ChannelContext context = channel.getContext();
context.channelRegistered();
if (context.hasQueuedWriteOps()) {
SelectionKeyUtils.setConnectReadAndWriteInterested(channel);
} else {
SelectionKeyUtils.setConnectAndReadInterested(channel); SelectionKeyUtils.setConnectAndReadInterested(channel);
} }
}
/** /**
* This method is called when an attempt to register a channel throws an exception. * This method is called when an attempt to register a channel throws an exception.
@ -86,10 +92,7 @@ public class SocketEventHandler extends EventHandler {
* @param channel that can be read * @param channel that can be read
*/ */
protected void handleRead(NioSocketChannel channel) throws IOException { protected void handleRead(NioSocketChannel channel) throws IOException {
int bytesRead = channel.getReadContext().read(); channel.getContext().read();
if (bytesRead == -1) {
handleClose(channel);
}
} }
/** /**
@ -107,16 +110,11 @@ public class SocketEventHandler extends EventHandler {
* This method is called when a channel signals it is ready to receive writes. All of the write logic * This method is called when a channel signals it is ready to receive writes. All of the write logic
* should occur in this call. * should occur in this call.
* *
* @param channel that can be read * @param channel that can be written to
*/ */
protected void handleWrite(NioSocketChannel channel) throws IOException { protected void handleWrite(NioSocketChannel channel) throws IOException {
WriteContext channelContext = channel.getWriteContext(); ChannelContext channelContext = channel.getContext();
channelContext.flushChannel(); channelContext.flushChannel();
if (channelContext.hasQueuedWriteOps()) {
SelectionKeyUtils.setWriteInterested(channel);
} else {
SelectionKeyUtils.removeWriteInterested(channel);
}
} }
/** /**
@ -153,6 +151,23 @@ public class SocketEventHandler extends EventHandler {
logger.warn(new ParameterizedMessage("exception while executing listener: {}", listener), exception); logger.warn(new ParameterizedMessage("exception while executing listener: {}", listener), exception);
} }
/**
* @param channel that was handled
*/
protected void postHandling(NioSocketChannel channel) {
if (channel.getContext().selectorShouldClose()) {
handleClose(channel);
} else {
boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(channel);
boolean pendingWrites = channel.getContext().hasQueuedWriteOps();
if (currentlyWriteInterested == false && pendingWrites) {
SelectionKeyUtils.setWriteInterested(channel);
} else if (currentlyWriteInterested && pendingWrites == false) {
SelectionKeyUtils.removeWriteInterested(channel);
}
}
}
private void exceptionCaught(NioSocketChannel channel, Exception e) { private void exceptionCaught(NioSocketChannel channel, Exception e) {
channel.getExceptionContext().accept(channel, e); channel.getExceptionContext().accept(channel, e);
} }

View File

@ -64,6 +64,8 @@ public class SocketSelector extends ESSelector {
handleRead(nioSocketChannel); handleRead(nioSocketChannel);
} }
} }
eventHandler.postHandling(nioSocketChannel);
} }
@Override @Override
@ -118,12 +120,12 @@ public class SocketSelector extends ESSelector {
* @param writeOperation to be queued in a channel's buffer * @param writeOperation to be queued in a channel's buffer
*/ */
public void queueWriteInChannelBuffer(WriteOperation writeOperation) { public void queueWriteInChannelBuffer(WriteOperation writeOperation) {
assert isOnCurrentThread() : "Must be on selector thread"; assertOnSelectorThread();
NioSocketChannel channel = writeOperation.getChannel(); NioSocketChannel channel = writeOperation.getChannel();
WriteContext context = channel.getWriteContext(); ChannelContext context = channel.getContext();
try { try {
SelectionKeyUtils.setWriteInterested(channel); SelectionKeyUtils.setWriteInterested(channel);
context.queueWriteOperations(writeOperation); context.queueWriteOperation(writeOperation);
} catch (Exception e) { } catch (Exception e) {
executeFailedListener(writeOperation.getListener(), e); executeFailedListener(writeOperation.getListener(), e);
} }
@ -137,7 +139,7 @@ public class SocketSelector extends ESSelector {
* @param value to provide to listener * @param value to provide to listener
*/ */
public <V> void executeListener(BiConsumer<V, Throwable> listener, V value) { public <V> void executeListener(BiConsumer<V, Throwable> listener, V value) {
assert isOnCurrentThread() : "Must be on selector thread"; assertOnSelectorThread();
try { try {
listener.accept(value, null); listener.accept(value, null);
} catch (Exception e) { } catch (Exception e) {
@ -153,7 +155,7 @@ public class SocketSelector extends ESSelector {
* @param exception to provide to listener * @param exception to provide to listener
*/ */
public <V> void executeFailedListener(BiConsumer<V, Throwable> listener, Exception exception) { public <V> void executeFailedListener(BiConsumer<V, Throwable> listener, Exception exception) {
assert isOnCurrentThread() : "Must be on selector thread"; assertOnSelectorThread();
try { try {
listener.accept(null, exception); listener.accept(null, exception);
} catch (Exception e) { } catch (Exception e) {
@ -180,7 +182,7 @@ public class SocketSelector extends ESSelector {
private void handleQueuedWrites() { private void handleQueuedWrites() {
WriteOperation writeOperation; WriteOperation writeOperation;
while ((writeOperation = queuedWrites.poll()) != null) { while ((writeOperation = queuedWrites.poll()) != null) {
if (writeOperation.getChannel().isWritable()) { if (writeOperation.getChannel().isOpen()) {
queueWriteInChannelBuffer(writeOperation); queueWriteInChannelBuffer(writeOperation);
} else { } else {
executeFailedListener(writeOperation.getListener(), new ClosedChannelException()); executeFailedListener(writeOperation.getListener(), new ClosedChannelException());

View File

@ -1,37 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
import java.util.function.BiConsumer;
public interface WriteContext {
void sendMessage(Object message, BiConsumer<Void, Throwable> listener);
void queueWriteOperations(WriteOperation writeOperation);
void flushChannel() throws IOException;
boolean hasQueuedWriteOps();
void clearQueuedWriteOps(Exception e);
}

View File

@ -19,74 +19,16 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
public class WriteOperation { /**
* This is a basic write operation that can be queued with a channel. The only requirements of a write
* operation is that is has a listener and a reference to its channel. The actual conversion of the write
* operation implementation to bytes will be performed by the {@link ChannelContext}.
*/
public interface WriteOperation {
private final NioSocketChannel channel; BiConsumer<Void, Throwable> getListener();
private final BiConsumer<Void, Throwable> listener;
private final ByteBuffer[] buffers;
private final int[] offsets;
private final int length;
private int internalIndex;
public WriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) { NioSocketChannel getChannel();
this.channel = channel;
this.listener = listener;
this.buffers = buffers;
this.offsets = new int[buffers.length];
int offset = 0;
for (int i = 0; i < buffers.length; i++) {
ByteBuffer buffer = buffers[i];
offsets[i] = offset;
offset += buffer.remaining();
}
length = offset;
}
public ByteBuffer[] getByteBuffers() {
return buffers;
}
public BiConsumer<Void, Throwable> getListener() {
return listener;
}
public NioSocketChannel getChannel() {
return channel;
}
public boolean isFullyFlushed() {
return internalIndex == length;
}
public int flush() throws IOException {
int written = channel.write(getBuffersToWrite());
internalIndex += written;
return written;
}
private ByteBuffer[] getBuffersToWrite() {
int offsetIndex = getOffsetIndex(internalIndex);
ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
firstBuffer.position(internalIndex - offsets[offsetIndex]);
postIndexBuffers[0] = firstBuffer;
int j = 1;
for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
postIndexBuffers[j++] = buffers[i].duplicate();
}
return postIndexBuffers;
}
private int getOffsetIndex(int offset) {
final int i = Arrays.binarySearch(offsets, offset);
return i < 0 ? (-(i + 1)) - 1 : i;
}
} }

View File

@ -80,7 +80,7 @@ public class AcceptorEventHandlerTests extends ESTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testHandleAcceptCallsServerAcceptCallback() throws IOException { public void testHandleAcceptCallsServerAcceptCallback() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector);
childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); childChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel);
handler.acceptChannel(channel); handler.acceptChannel(channel);

View File

@ -0,0 +1,337 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BytesChannelContextTests extends ESTestCase {
private ChannelContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private BytesChannelContext context;
private InboundChannelBuffer channelBuffer;
private SocketSelector selector;
private BiConsumer<Void, Throwable> listener;
private int messageLength;
@Before
@SuppressWarnings("unchecked")
public void init() {
readConsumer = mock(ChannelContext.ReadConsumer.class);
messageLength = randomInt(96) + 20;
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
channelBuffer = new InboundChannelBuffer(pageSupplier);
context = new BytesChannelContext(channel, readConsumer, channelBuffer);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
assertEquals(messageLength, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
}
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
assertEquals(bytes.length, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(0);
assertEquals(messageLength, context.read());
assertEquals(bytes.length, channelBuffer.getIndex());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
assertEquals(messageLength, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException);
IOException ex = expectThrows(IOException.class, () -> context.read());
assertSame(ioException, ex);
}
public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer[].class))).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.read());
assertTrue(context.selectorShouldClose());
}
public void testReadLessThanZeroMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer[].class))).thenReturn(-1);
assertEquals(0, context.read());
assertTrue(context.selectorShouldClose());
}
public void testCloseClosesChannelBuffer() throws IOException {
Runnable closer = mock(Runnable.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
context.closeFromSelector();
verify(closer).run();
}
public void testWriteFailsIfClosing() {
context.closeChannel();
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
}
public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
when(selector.isOnCurrentThread()).thenReturn(false);
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
public void testSendMessageFromSameThreadIsQueuedInChannel() {
ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
public void testWriteIsQueuedInChannel() {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
assertTrue(context.hasQueuedWriteOps());
}
public void testWriteOpsClearedOnClose() throws Exception {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
assertTrue(context.hasQueuedWriteOps());
context.closeFromSelector();
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
assertFalse(context.hasQueuedWriteOps());
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
assertTrue(context.hasQueuedWriteOps());
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(writeOperation.isFullyFlushed()).thenReturn(true);
when(writeOperation.getListener()).thenReturn(listener);
context.flushChannel();
verify(channel).write(buffers);
verify(selector).executeListener(listener, null);
assertFalse(context.hasQueuedWriteOps());
}
public void testPartialFlush() throws IOException {
assertFalse(context.hasQueuedWriteOps());
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
assertTrue(context.hasQueuedWriteOps());
when(writeOperation.isFullyFlushed()).thenReturn(false);
context.flushChannel();
verify(listener, times(0)).accept(null, null);
assertTrue(context.hasQueuedWriteOps());
}
@SuppressWarnings("unchecked")
public void testMultipleWritesPartialFlushes() throws IOException {
assertFalse(context.hasQueuedWriteOps());
BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class);
BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class);
when(writeOperation1.getListener()).thenReturn(listener);
when(writeOperation2.getListener()).thenReturn(listener2);
context.queueWriteOperation(writeOperation1);
context.queueWriteOperation(writeOperation2);
assertTrue(context.hasQueuedWriteOps());
when(writeOperation1.isFullyFlushed()).thenReturn(true);
when(writeOperation2.isFullyFlushed()).thenReturn(false);
context.flushChannel();
verify(selector).executeListener(listener, null);
verify(listener2, times(0)).accept(null, null);
assertTrue(context.hasQueuedWriteOps());
when(writeOperation2.isFullyFlushed()).thenReturn(true);
context.flushChannel();
verify(selector).executeListener(listener2, null);
assertFalse(context.hasQueuedWriteOps());
}
public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
assertTrue(context.hasQueuedWriteOps());
IOException exception = new IOException();
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(channel.write(buffers)).thenThrow(exception);
when(writeOperation.getListener()).thenReturn(listener);
expectThrows(IOException.class, () -> context.flushChannel());
verify(selector).executeFailedListener(listener, exception);
assertFalse(context.hasQueuedWriteOps());
}
public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
IOException exception = new IOException();
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(channel.write(buffers)).thenThrow(exception);
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.flushChannel());
assertTrue(context.selectorShouldClose());
}
public void initiateCloseSchedulesCloseWithSelector() {
context.closeChannel();
verify(selector).queueChannelClose(channel);
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -1,142 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BytesReadContextTests extends ESTestCase {
private ReadContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private BytesReadContext readContext;
private InboundChannelBuffer channelBuffer;
private int messageLength;
@Before
public void init() {
readConsumer = mock(ReadContext.ReadConsumer.class);
messageLength = randomInt(96) + 20;
channel = mock(NioSocketChannel.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
channelBuffer = new InboundChannelBuffer(pageSupplier);
readContext = new BytesReadContext(channel, readConsumer, channelBuffer);
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
assertEquals(messageLength, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
assertEquals(bytes.length, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(3)).consumeReads(channelBuffer);
}
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(0, messageLength);
assertEquals(messageLength, readContext.read());
assertEquals(bytes.length, channelBuffer.getIndex());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
assertEquals(messageLength, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
verify(readConsumer, times(3)).consumeReads(channelBuffer);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException);
IOException ex = expectThrows(IOException.class, () -> readContext.read());
assertSame(ioException, ex);
}
public void closeClosesChannelBuffer() {
InboundChannelBuffer buffer = mock(InboundChannelBuffer.class);
BytesReadContext readContext = new BytesReadContext(channel, readConsumer, buffer);
readContext.close();
verify(buffer).close();
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -1,212 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.function.BiConsumer;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BytesWriteContextTests extends ESTestCase {
private SocketSelector selector;
private BiConsumer<Void, Throwable> listener;
private BytesWriteContext writeContext;
private NioSocketChannel channel;
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
super.setUp();
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
writeContext = new BytesWriteContext(channel);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
}
public void testWriteFailsIfChannelNotWritable() throws Exception {
when(channel.isWritable()).thenReturn(false);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
}
public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(selector.isOnCurrentThread()).thenReturn(false);
when(channel.isWritable()).thenReturn(true);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
}
public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception {
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(channel.isWritable()).thenReturn(true);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
}
public void testWriteIsQueuedInChannel() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
writeContext.queueWriteOperations(new WriteOperation(channel, buffer, listener));
assertTrue(writeContext.hasQueuedWriteOps());
}
public void testWriteOpsCanBeCleared() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
writeContext.queueWriteOperations(new WriteOperation(channel, buffer, listener));
assertTrue(writeContext.hasQueuedWriteOps());
ClosedChannelException e = new ClosedChannelException();
writeContext.clearQueuedWriteOps(e);
verify(selector).executeFailedListener(listener, e);
assertFalse(writeContext.hasQueuedWriteOps());
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
WriteOperation writeOperation = mock(WriteOperation.class);
writeContext.queueWriteOperations(writeOperation);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation.isFullyFlushed()).thenReturn(true);
when(writeOperation.getListener()).thenReturn(listener);
writeContext.flushChannel();
verify(writeOperation).flush();
verify(selector).executeListener(listener, null);
assertFalse(writeContext.hasQueuedWriteOps());
}
public void testPartialFlush() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
WriteOperation writeOperation = mock(WriteOperation.class);
writeContext.queueWriteOperations(writeOperation);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel();
verify(listener, times(0)).accept(null, null);
assertTrue(writeContext.hasQueuedWriteOps());
}
@SuppressWarnings("unchecked")
public void testMultipleWritesPartialFlushes() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
WriteOperation writeOperation1 = mock(WriteOperation.class);
WriteOperation writeOperation2 = mock(WriteOperation.class);
when(writeOperation1.getListener()).thenReturn(listener);
when(writeOperation2.getListener()).thenReturn(listener2);
writeContext.queueWriteOperations(writeOperation1);
writeContext.queueWriteOperations(writeOperation2);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation1.isFullyFlushed()).thenReturn(true);
when(writeOperation2.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel();
verify(selector).executeListener(listener, null);
verify(listener2, times(0)).accept(null, null);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation2.isFullyFlushed()).thenReturn(true);
writeContext.flushChannel();
verify(selector).executeListener(listener2, null);
assertFalse(writeContext.hasQueuedWriteOps());
}
public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
WriteOperation writeOperation = mock(WriteOperation.class);
writeContext.queueWriteOperations(writeOperation);
assertTrue(writeContext.hasQueuedWriteOps());
IOException exception = new IOException();
when(writeOperation.flush()).thenThrow(exception);
when(writeOperation.getListener()).thenReturn(listener);
expectThrows(IOException.class, () -> writeContext.flushChannel());
verify(selector).executeFailedListener(listener, exception);
assertFalse(writeContext.hasQueuedWriteOps());
}
private byte[] generateBytes(int n) {
n += 10;
byte[] bytes = new byte[n];
for (int i = 0; i < n; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -139,7 +139,7 @@ public class ChannelFactoryTests extends ESTestCase {
@Override @Override
public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException {
NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector);
nioSocketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); nioSocketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
return nioSocketChannel; return nioSocketChannel;
} }

View File

@ -82,7 +82,7 @@ public class NioServerSocketChannelTests extends ESTestCase {
PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
channel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); channel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
channel.close(); selector.queueChannelClose(channel);
closeFuture.actionGet(); closeFuture.actionGet();

View File

@ -35,6 +35,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -66,7 +67,7 @@ public class NioSocketChannelTests extends ESTestCase {
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector); NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector);
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener<Void>() { socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener<Void>() {
@Override @Override
public void onResponse(Void o) { public void onResponse(Void o) {
@ -86,7 +87,45 @@ public class NioSocketChannelTests extends ESTestCase {
PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture)); socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
socketChannel.close(); selector.queueChannelClose(socketChannel);
closeFuture.actionGet();
assertTrue(closedRawChannel.get());
assertFalse(socketChannel.isOpen());
latch.await();
assertTrue(isClosed.get());
}
@SuppressWarnings("unchecked")
public void testCloseContextExceptionDoesNotStopClose() throws Exception {
AtomicBoolean isClosed = new AtomicBoolean(false);
CountDownLatch latch = new CountDownLatch(1);
IOException ioException = new IOException();
NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector);
ChannelContext context = mock(ChannelContext.class);
doThrow(ioException).when(context).closeFromSelector();
socketChannel.setContexts(context, mock(BiConsumer.class));
socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener<Void>() {
@Override
public void onResponse(Void o) {
isClosed.set(true);
latch.countDown();
}
@Override
public void onFailure(Exception e) {
isClosed.set(true);
latch.countDown();
}
}));
assertTrue(socketChannel.isOpen());
assertFalse(closedRawChannel.get());
assertFalse(isClosed.get());
PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
selector.queueChannelClose(socketChannel);
closeFuture.actionGet(); closeFuture.actionGet();
assertTrue(closedRawChannel.get()); assertTrue(closedRawChannel.get());
@ -100,7 +139,7 @@ public class NioSocketChannelTests extends ESTestCase {
SocketChannel rawChannel = mock(SocketChannel.class); SocketChannel rawChannel = mock(SocketChannel.class);
when(rawChannel.finishConnect()).thenReturn(true); when(rawChannel.finishConnect()).thenReturn(true);
NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector);
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
selector.scheduleForRegistration(socketChannel); selector.scheduleForRegistration(socketChannel);
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
@ -117,7 +156,7 @@ public class NioSocketChannelTests extends ESTestCase {
SocketChannel rawChannel = mock(SocketChannel.class); SocketChannel rawChannel = mock(SocketChannel.class);
when(rawChannel.finishConnect()).thenThrow(new ConnectException()); when(rawChannel.finishConnect()).thenThrow(new ConnectException());
NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector);
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
selector.scheduleForRegistration(socketChannel); selector.scheduleForRegistration(socketChannel);
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();

View File

@ -28,8 +28,10 @@ import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -39,7 +41,6 @@ public class SocketEventHandlerTests extends ESTestCase {
private SocketEventHandler handler; private SocketEventHandler handler;
private NioSocketChannel channel; private NioSocketChannel channel;
private ReadContext readContext;
private SocketChannel rawChannel; private SocketChannel rawChannel;
@Before @Before
@ -50,21 +51,37 @@ public class SocketEventHandlerTests extends ESTestCase {
handler = new SocketEventHandler(logger); handler = new SocketEventHandler(logger);
rawChannel = mock(SocketChannel.class); rawChannel = mock(SocketChannel.class);
channel = new DoNotRegisterChannel(rawChannel, socketSelector); channel = new DoNotRegisterChannel(rawChannel, socketSelector);
readContext = mock(ReadContext.class);
when(rawChannel.finishConnect()).thenReturn(true); when(rawChannel.finishConnect()).thenReturn(true);
channel.setContexts(readContext, new BytesWriteContext(channel), exceptionHandler); Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {});
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
channel.setContexts(new BytesChannelContext(channel, mock(ChannelContext.ReadConsumer.class), buffer), exceptionHandler);
channel.register(); channel.register();
channel.finishConnect(); channel.finishConnect();
when(socketSelector.isOnCurrentThread()).thenReturn(true); when(socketSelector.isOnCurrentThread()).thenReturn(true);
} }
public void testRegisterCallsContext() throws IOException {
NioSocketChannel channel = mock(NioSocketChannel.class);
ChannelContext channelContext = mock(ChannelContext.class);
when(channel.getContext()).thenReturn(channelContext);
when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
handler.handleRegistration(channel);
verify(channelContext).channelRegistered();
}
public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException { public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException {
handler.handleRegistration(channel); handler.handleRegistration(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps()); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps());
} }
public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException {
channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class));
handler.handleRegistration(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
}
public void testRegistrationExceptionCallsExceptionHandler() throws IOException { public void testRegistrationExceptionCallsExceptionHandler() throws IOException {
CancelledKeyException exception = new CancelledKeyException(); CancelledKeyException exception = new CancelledKeyException();
handler.registrationException(channel, exception); handler.registrationException(channel, exception);
@ -83,68 +100,76 @@ public class SocketEventHandlerTests extends ESTestCase {
verify(exceptionHandler).accept(channel, exception); verify(exceptionHandler).accept(channel, exception);
} }
public void testHandleReadDelegatesToReadContext() throws IOException { public void testHandleReadDelegatesToContext() throws IOException {
when(readContext.read()).thenReturn(1); NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
ChannelContext context = mock(ChannelContext.class);
channel.setContexts(context, exceptionHandler);
when(context.read()).thenReturn(1);
handler.handleRead(channel); handler.handleRead(channel);
verify(context).read();
verify(readContext).read();
} }
public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException { public void testReadExceptionCallsExceptionHandler() {
NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class);
when(nioSocketChannel.getReadContext()).thenReturn(readContext);
when(readContext.read()).thenReturn(-1);
handler.handleRead(nioSocketChannel);
verify(nioSocketChannel).closeFromSelector();
}
public void testReadExceptionCallsExceptionHandler() throws IOException {
IOException exception = new IOException(); IOException exception = new IOException();
handler.readException(channel, exception); handler.readException(channel, exception);
verify(exceptionHandler).accept(channel, exception); verify(exceptionHandler).accept(channel, exception);
} }
@SuppressWarnings("unchecked") public void testWriteExceptionCallsExceptionHandler() {
public void testHandleWriteWithCompleteFlushRemovesOP_WRITEInterest() throws IOException { IOException exception = new IOException();
SelectionKey selectionKey = channel.getSelectionKey(); handler.writeException(channel, exception);
setWriteAndRead(channel); verify(exceptionHandler).accept(channel, exception);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
ByteBuffer[] buffers = {ByteBuffer.allocate(1)};
channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class)));
when(rawChannel.write(buffers[0])).thenReturn(1);
handler.handleWrite(channel);
assertEquals(SelectionKey.OP_READ, selectionKey.interestOps());
} }
@SuppressWarnings("unchecked") public void testPostHandlingCallWillCloseTheChannelIfReady() throws IOException {
public void testHandleWriteWithInCompleteFlushLeavesOP_WRITEInterest() throws IOException { NioSocketChannel channel = mock(NioSocketChannel.class);
SelectionKey selectionKey = channel.getSelectionKey(); ChannelContext context = mock(ChannelContext.class);
setWriteAndRead(channel); when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
ByteBuffer[] buffers = {ByteBuffer.allocate(1)}; when(channel.getContext()).thenReturn(context);
channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class))); when(context.selectorShouldClose()).thenReturn(true);
handler.postHandling(channel);
when(rawChannel.write(buffers[0])).thenReturn(0); verify(channel).closeFromSelector();
handler.handleWrite(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
} }
public void testHandleWriteWithNoOpsRemovesOP_WRITEInterest() throws IOException { public void testPostHandlingCallWillNotCloseTheChannelIfNotReady() throws IOException {
SelectionKey selectionKey = channel.getSelectionKey(); NioSocketChannel channel = mock(NioSocketChannel.class);
setWriteAndRead(channel); ChannelContext context = mock(ChannelContext.class);
when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
when(channel.getContext()).thenReturn(context);
when(context.selectorShouldClose()).thenReturn(false);
handler.postHandling(channel);
verify(channel, times(0)).closeFromSelector();
}
public void testPostHandlingWillAddWriteIfNecessary() throws IOException {
NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ));
ChannelContext context = mock(ChannelContext.class);
channel.setContexts(context, null);
when(context.hasQueuedWriteOps()).thenReturn(true);
assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps());
handler.postHandling(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
}
handler.handleWrite(channel); public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException {
NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE));
ChannelContext context = mock(ChannelContext.class);
channel.setContexts(context, null);
assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); when(context.hasQueuedWriteOps()).thenReturn(false);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
handler.postHandling(channel);
assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps());
} }
private void setWriteAndRead(NioChannel channel) { private void setWriteAndRead(NioChannel channel) {
@ -152,10 +177,4 @@ public class SocketEventHandlerTests extends ESTestCase {
SelectionKeyUtils.removeConnectInterested(channel); SelectionKeyUtils.removeConnectInterested(channel);
SelectionKeyUtils.setWriteInterested(channel); SelectionKeyUtils.setWriteInterested(channel);
} }
public void testWriteExceptionCallsExceptionHandler() throws IOException {
IOException exception = new IOException();
handler.writeException(channel, exception);
verify(exceptionHandler).accept(channel, exception);
}
} }

View File

@ -49,7 +49,7 @@ public class SocketSelectorTests extends ESTestCase {
private SocketEventHandler eventHandler; private SocketEventHandler eventHandler;
private NioSocketChannel channel; private NioSocketChannel channel;
private TestSelectionKey selectionKey; private TestSelectionKey selectionKey;
private WriteContext writeContext; private ChannelContext channelContext;
private BiConsumer<Void, Throwable> listener; private BiConsumer<Void, Throwable> listener;
private ByteBuffer[] buffers = {ByteBuffer.allocate(1)}; private ByteBuffer[] buffers = {ByteBuffer.allocate(1)};
private Selector rawSelector; private Selector rawSelector;
@ -60,7 +60,7 @@ public class SocketSelectorTests extends ESTestCase {
super.setUp(); super.setUp();
eventHandler = mock(SocketEventHandler.class); eventHandler = mock(SocketEventHandler.class);
channel = mock(NioSocketChannel.class); channel = mock(NioSocketChannel.class);
writeContext = mock(WriteContext.class); channelContext = mock(ChannelContext.class);
listener = mock(BiConsumer.class); listener = mock(BiConsumer.class);
selectionKey = new TestSelectionKey(0); selectionKey = new TestSelectionKey(0);
selectionKey.attach(channel); selectionKey.attach(channel);
@ -71,7 +71,7 @@ public class SocketSelectorTests extends ESTestCase {
when(channel.isOpen()).thenReturn(true); when(channel.isOpen()).thenReturn(true);
when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getSelectionKey()).thenReturn(selectionKey);
when(channel.getWriteContext()).thenReturn(writeContext); when(channel.getContext()).thenReturn(channelContext);
when(channel.isConnectComplete()).thenReturn(true); when(channel.isConnectComplete()).thenReturn(true);
when(channel.getSelector()).thenReturn(socketSelector); when(channel.getSelector()).thenReturn(socketSelector);
} }
@ -129,75 +129,71 @@ public class SocketSelectorTests extends ESTestCase {
public void testQueueWriteWhenNotRunning() throws Exception { public void testQueueWriteWhenNotRunning() throws Exception {
socketSelector.close(); socketSelector.close();
socketSelector.queueWrite(new WriteOperation(channel, buffers, listener)); socketSelector.queueWrite(new BytesWriteOperation(channel, buffers, listener));
verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class));
} }
public void testQueueWriteChannelIsNoLongerWritable() throws Exception { public void testQueueWriteChannelIsClosed() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
socketSelector.queueWrite(writeOperation); socketSelector.queueWrite(writeOperation);
when(channel.isWritable()).thenReturn(false); when(channel.isOpen()).thenReturn(false);
socketSelector.preSelect(); socketSelector.preSelect();
verify(writeContext, times(0)).queueWriteOperations(writeOperation); verify(channelContext, times(0)).queueWriteOperation(writeOperation);
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
} }
public void testQueueWriteSelectionKeyThrowsException() throws Exception { public void testQueueWriteSelectionKeyThrowsException() throws Exception {
SelectionKey selectionKey = mock(SelectionKey.class); SelectionKey selectionKey = mock(SelectionKey.class);
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
CancelledKeyException cancelledKeyException = new CancelledKeyException(); CancelledKeyException cancelledKeyException = new CancelledKeyException();
socketSelector.queueWrite(writeOperation); socketSelector.queueWrite(writeOperation);
when(channel.isWritable()).thenReturn(true);
when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getSelectionKey()).thenReturn(selectionKey);
when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
socketSelector.preSelect(); socketSelector.preSelect();
verify(writeContext, times(0)).queueWriteOperations(writeOperation); verify(channelContext, times(0)).queueWriteOperation(writeOperation);
verify(listener).accept(null, cancelledKeyException); verify(listener).accept(null, cancelledKeyException);
} }
public void testQueueWriteSuccessful() throws Exception { public void testQueueWriteSuccessful() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
socketSelector.queueWrite(writeOperation); socketSelector.queueWrite(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
when(channel.isWritable()).thenReturn(true);
socketSelector.preSelect(); socketSelector.preSelect();
verify(writeContext).queueWriteOperations(writeOperation); verify(channelContext).queueWriteOperation(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
} }
public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { public void testQueueDirectlyInChannelBufferSuccessful() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
when(channel.isWritable()).thenReturn(true);
socketSelector.queueWriteInChannelBuffer(writeOperation); socketSelector.queueWriteInChannelBuffer(writeOperation);
verify(writeContext).queueWriteOperations(writeOperation); verify(channelContext).queueWriteOperation(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
} }
public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception {
SelectionKey selectionKey = mock(SelectionKey.class); SelectionKey selectionKey = mock(SelectionKey.class);
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
CancelledKeyException cancelledKeyException = new CancelledKeyException(); CancelledKeyException cancelledKeyException = new CancelledKeyException();
when(channel.isWritable()).thenReturn(true);
when(channel.getSelectionKey()).thenReturn(selectionKey); when(channel.getSelectionKey()).thenReturn(selectionKey);
when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
socketSelector.queueWriteInChannelBuffer(writeOperation); socketSelector.queueWriteInChannelBuffer(writeOperation);
verify(writeContext, times(0)).queueWriteOperations(writeOperation); verify(channelContext, times(0)).queueWriteOperation(writeOperation);
verify(listener).accept(null, cancelledKeyException); verify(listener).accept(null, cancelledKeyException);
} }
@ -285,6 +281,16 @@ public class SocketSelectorTests extends ESTestCase {
verify(eventHandler).readException(channel, ioException); verify(eventHandler).readException(channel, ioException);
} }
public void testWillCallPostHandleAfterChannelHandling() throws Exception {
selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
socketSelector.processKey(selectionKey);
verify(eventHandler).handleWrite(channel);
verify(eventHandler).handleRead(channel);
verify(eventHandler).postHandling(channel);
}
public void testCleanup() throws Exception { public void testCleanup() throws Exception {
NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class);
@ -292,7 +298,7 @@ public class SocketSelectorTests extends ESTestCase {
socketSelector.preSelect(); socketSelector.preSelect();
socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), buffers, listener)); socketSelector.queueWrite(new BytesWriteOperation(mock(NioSocketChannel.class), buffers, listener));
socketSelector.scheduleForRegistration(unRegisteredChannel); socketSelector.scheduleForRegistration(unRegisteredChannel);
TestSelectionKey testSelectionKey = new TestSelectionKey(0); TestSelectionKey testSelectionKey = new TestSelectionKey(0);

View File

@ -45,71 +45,58 @@ public class WriteOperationTests extends ESTestCase {
} }
public void testFlush() throws IOException { public void testFullyFlushedMarker() {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
WriteOperation writeOp = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
writeOp.incrementIndex(10);
when(channel.write(any(ByteBuffer[].class))).thenReturn(10);
writeOp.flush();
assertTrue(writeOp.isFullyFlushed()); assertTrue(writeOp.isFullyFlushed());
} }
public void testPartialFlush() throws IOException { public void testPartiallyFlushedMarker() {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
WriteOperation writeOp = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
when(channel.write(any(ByteBuffer[].class))).thenReturn(5); writeOp.incrementIndex(5);
writeOp.flush();
assertFalse(writeOp.isFullyFlushed()); assertFalse(writeOp.isFullyFlushed());
} }
public void testMultipleFlushesWithCompositeBuffer() throws IOException { public void testMultipleFlushesWithCompositeBuffer() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)}; ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)};
WriteOperation writeOp = new WriteOperation(channel, buffers, listener); BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
ArgumentCaptor<ByteBuffer[]> buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class); ArgumentCaptor<ByteBuffer[]> buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class);
when(channel.write(buffersCaptor.capture())).thenReturn(5) writeOp.incrementIndex(5);
.thenReturn(5)
.thenReturn(2)
.thenReturn(15)
.thenReturn(1);
writeOp.flush();
assertFalse(writeOp.isFullyFlushed()); assertFalse(writeOp.isFullyFlushed());
writeOp.flush(); ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite();
assertFalse(writeOp.isFullyFlushed());
writeOp.flush();
assertFalse(writeOp.isFullyFlushed());
writeOp.flush();
assertFalse(writeOp.isFullyFlushed());
writeOp.flush();
assertTrue(writeOp.isFullyFlushed());
List<ByteBuffer[]> values = buffersCaptor.getAllValues();
ByteBuffer[] byteBuffers = values.get(0);
assertEquals(3, byteBuffers.length);
assertEquals(10, byteBuffers[0].remaining());
byteBuffers = values.get(1);
assertEquals(3, byteBuffers.length); assertEquals(3, byteBuffers.length);
assertEquals(5, byteBuffers[0].remaining()); assertEquals(5, byteBuffers[0].remaining());
byteBuffers = values.get(2); writeOp.incrementIndex(5);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length); assertEquals(2, byteBuffers.length);
assertEquals(15, byteBuffers[0].remaining()); assertEquals(15, byteBuffers[0].remaining());
byteBuffers = values.get(3); writeOp.incrementIndex(2);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length); assertEquals(2, byteBuffers.length);
assertEquals(13, byteBuffers[0].remaining()); assertEquals(13, byteBuffers[0].remaining());
byteBuffers = values.get(4); writeOp.incrementIndex(15);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(1, byteBuffers.length); assertEquals(1, byteBuffers.length);
assertEquals(1, byteBuffers[0].remaining()); assertEquals(1, byteBuffers[0].remaining());
writeOp.incrementIndex(1);
assertTrue(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(1, byteBuffers.length);
assertEquals(0, byteBuffers[0].remaining());
} }
} }

View File

@ -33,13 +33,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.AcceptorEventHandler; import org.elasticsearch.nio.AcceptorEventHandler;
import org.elasticsearch.nio.BytesReadContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteContext; import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadContext;
import org.elasticsearch.nio.SocketEventHandler; import org.elasticsearch.nio.SocketEventHandler;
import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -72,12 +71,12 @@ public class NioTransport extends TcpTransport {
public static final Setting<Integer> NIO_ACCEPTOR_COUNT = public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);
private final PageCacheRecycler pageCacheRecycler; protected final PageCacheRecycler pageCacheRecycler;
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap(); private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private volatile NioGroup nioGroup; private volatile NioGroup nioGroup;
private volatile TcpChannelFactory clientChannelFactory; private volatile TcpChannelFactory clientChannelFactory;
NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, protected NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService) { CircuitBreakerService circuitBreakerService) {
super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
@ -111,13 +110,13 @@ public class NioTransport extends TcpTransport {
NioTransport.NIO_WORKER_COUNT.get(settings), SocketEventHandler::new); NioTransport.NIO_WORKER_COUNT.get(settings), SocketEventHandler::new);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = new TcpChannelFactory(clientProfileSettings); clientChannelFactory = channelFactory(clientProfileSettings, true);
if (useNetworkServer) { if (useNetworkServer) {
// loop through all profiles and start them up, special handling for default one // loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) { for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName; String profileName = profileSettings.profileName;
TcpChannelFactory factory = new TcpChannelFactory(profileSettings); TcpChannelFactory factory = channelFactory(profileSettings, false);
profileToChannelFactory.putIfAbsent(profileName, factory); profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings); bindServer(profileSettings);
} }
@ -144,19 +143,30 @@ public class NioTransport extends TcpTransport {
profileToChannelFactory.clear(); profileToChannelFactory.clear();
} }
private void exceptionCaught(NioSocketChannel channel, Exception exception) { protected void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException((TcpChannel) channel, exception); onException((TcpChannel) channel, exception);
} }
private void acceptChannel(NioSocketChannel channel) { protected void acceptChannel(NioSocketChannel channel) {
serverAcceptedChannel((TcpNioSocketChannel) channel); serverAcceptedChannel((TcpNioSocketChannel) channel);
} }
private class TcpChannelFactory extends ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> { protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) {
return new TcpChannelFactoryImpl(settings);
}
protected abstract class TcpChannelFactory extends ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> {
protected TcpChannelFactory(RawChannelFactory rawChannelFactory) {
super(rawChannelFactory);
}
}
private class TcpChannelFactoryImpl extends TcpChannelFactory {
private final String profileName; private final String profileName;
TcpChannelFactory(TcpTransport.ProfileSettings profileSettings) { private TcpChannelFactoryImpl(ProfileSettings profileSettings) {
super(new RawChannelFactory(profileSettings.tcpNoDelay, super(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.tcpKeepAlive,
profileSettings.reuseAddress, profileSettings.reuseAddress,
@ -172,10 +182,10 @@ public class NioTransport extends TcpTransport {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false); Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
}; };
ReadContext.ReadConsumer nioReadConsumer = channelBuffer -> ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
nioChannel.setContexts(readContext, new BytesWriteContext(nioChannel), NioTransport.this::exceptionCaught); nioChannel.setContexts(context, NioTransport.this::exceptionCaught);
return nioChannel; return nioChannel;
} }

View File

@ -38,7 +38,7 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements
private final String profile; private final String profile;
TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel,
ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> channelFactory, ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> channelFactory,
AcceptingSelector selector) throws IOException { AcceptingSelector selector) throws IOException {
super(socketChannel, channelFactory, selector); super(socketChannel, channelFactory, selector);
@ -60,6 +60,11 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements
return null; return null;
} }
@Override
public void close() {
getSelector().queueChannelClose(this);
}
@Override @Override
public String getProfile() { public String getProfile() {
return profile; return profile;

View File

@ -33,13 +33,13 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel
private final String profile; private final String profile;
TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { public TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(socketChannel, selector); super(socketChannel, selector);
this.profile = profile; this.profile = profile;
} }
public void sendMessage(BytesReference reference, ActionListener<Void> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
} }
@Override @Override
@ -59,6 +59,11 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel
addCloseListener(ActionListener.toBiConsumer(listener)); addCloseListener(ActionListener.toBiConsumer(listener));
} }
@Override
public void close() {
getContext().closeChannel();
}
@Override @Override
public String toString() { public String toString() {
return "TcpNioSocketChannel{" + return "TcpNioSocketChannel{" +

View File

@ -31,14 +31,13 @@ import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.AcceptorEventHandler; import org.elasticsearch.nio.AcceptorEventHandler;
import org.elasticsearch.nio.BytesReadContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteContext; import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadContext;
import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
@ -162,11 +161,10 @@ public class MockNioTransport extends TcpTransport {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false); Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
}; };
ReadContext.ReadConsumer nioReadConsumer = channelBuffer -> ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier)); BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
BytesWriteContext writeContext = new BytesWriteContext(nioChannel); nioChannel.setContexts(context, MockNioTransport.this::exceptionCaught);
nioChannel.setContexts(readContext, writeContext, MockNioTransport.this::exceptionCaught);
return nioChannel; return nioChannel;
} }
@ -188,6 +186,11 @@ public class MockNioTransport extends TcpTransport {
this.profile = profile; this.profile = profile;
} }
@Override
public void close() {
getSelector().queueChannelClose(this);
}
@Override @Override
public String getProfile() { public String getProfile() {
return profile; return profile;
@ -224,6 +227,11 @@ public class MockNioTransport extends TcpTransport {
this.profile = profile; this.profile = profile;
} }
@Override
public void close() {
getContext().closeChannel();
}
@Override @Override
public String getProfile() { public String getProfile() {
return profile; return profile;
@ -243,7 +251,7 @@ public class MockNioTransport extends TcpTransport {
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener)); getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
} }
} }
} }