Add DirectByteBuffer strategy for transport-nio (#36289)
This is related to #27260. In Elasticsearch all of the messages that we serialize to write to the network are composed of heap bytes. When you read or write to a nio socket in java, the heap memory you passed down must be copied to/from direct memory. The JVM internally does some buffering of the direct memory, however it is essentially unbounded. This commit introduces a simple mechanism of buffering and copying the memory in transport-nio. Each network event loop is given a 64kb DirectByteBuffer. When we go to read we use this buffer and copy the data after the read. Additionally, when we go to write, we copy the data to the direct memory before calling write. 64KB is chosen as this is the default receive buffer size we use for transport-netty4 (NETTY_RECEIVE_PREDICTOR_SIZE). Since we only have one buffer per thread, we could afford larger. However, if we the buffer is large and not all of the data is flushed in a write call, we will do excess copies. This is something we can explore in the future.
This commit is contained in:
parent
fc85c37efc
commit
373c67dd7a
|
@ -38,19 +38,12 @@ public class BytesChannelContext extends SocketChannelContext {
|
|||
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
if (channelBuffer.getRemaining() == 0) {
|
||||
// Requiring one additional byte will ensure that a new page is allocated.
|
||||
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
|
||||
}
|
||||
|
||||
int bytesRead = readFromChannel(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
|
||||
int bytesRead = readFromChannel(channelBuffer);
|
||||
|
||||
if (bytesRead == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
channelBuffer.incrementIndex(bytesRead);
|
||||
|
||||
handleReadBytes();
|
||||
|
||||
return bytesRead;
|
||||
|
@ -91,8 +84,7 @@ public class BytesChannelContext extends SocketChannelContext {
|
|||
* Returns a boolean indicating if the operation was fully flushed.
|
||||
*/
|
||||
private boolean singleFlush(FlushOperation flushOperation) throws IOException {
|
||||
int written = flushToChannel(flushOperation.getBuffersToWrite());
|
||||
flushOperation.incrementIndex(written);
|
||||
flushToChannel(flushOperation);
|
||||
return flushOperation.isFullyFlushed();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.elasticsearch.nio;
|
|||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.CancelledKeyException;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.nio.channels.ClosedSelectorException;
|
||||
|
@ -51,6 +52,7 @@ public class NioSelector implements Closeable {
|
|||
private final ConcurrentLinkedQueue<ChannelContext<?>> channelsToRegister = new ConcurrentLinkedQueue<>();
|
||||
private final EventHandler eventHandler;
|
||||
private final Selector selector;
|
||||
private final ByteBuffer ioBuffer;
|
||||
|
||||
private final ReentrantLock runLock = new ReentrantLock();
|
||||
private final CountDownLatch exitedLoop = new CountDownLatch(1);
|
||||
|
@ -65,6 +67,18 @@ public class NioSelector implements Closeable {
|
|||
public NioSelector(EventHandler eventHandler, Selector selector) {
|
||||
this.selector = selector;
|
||||
this.eventHandler = eventHandler;
|
||||
this.ioBuffer = ByteBuffer.allocateDirect(1 << 16);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a cached direct byte buffer for network operations. It is cleared on every get call.
|
||||
*
|
||||
* @return the byte buffer
|
||||
*/
|
||||
public ByteBuffer getIoBuffer() {
|
||||
assertOnSelectorThread();
|
||||
ioBuffer.clear();
|
||||
return ioBuffer;
|
||||
}
|
||||
|
||||
public Selector rawSelector() {
|
||||
|
|
|
@ -44,7 +44,7 @@ import java.util.function.Predicate;
|
|||
*/
|
||||
public abstract class SocketChannelContext extends ChannelContext<SocketChannel> {
|
||||
|
||||
public static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
|
||||
protected static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
|
||||
|
||||
protected final NioSocketChannel channel;
|
||||
protected final InboundChannelBuffer channelBuffer;
|
||||
|
@ -234,49 +234,113 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
return closeNow;
|
||||
}
|
||||
|
||||
|
||||
// When you read or write to a nio socket in java, the heap memory passed down must be copied to/from
|
||||
// direct memory. The JVM internally does some buffering of the direct memory, however we can save space
|
||||
// by reusing a thread-local direct buffer (provided by the NioSelector).
|
||||
//
|
||||
// Each network event loop is given a 64kb DirectByteBuffer. When we read we use this buffer and copy the
|
||||
// data after the read. When we go to write, we copy the data to the direct memory before calling write.
|
||||
// The choice of 64KB is rather arbitrary. We can explore different sizes in the future. However, any
|
||||
// data that is copied to the buffer for a write, but not successfully flushed immediately, must be
|
||||
// copied again on the next call.
|
||||
|
||||
protected int readFromChannel(ByteBuffer buffer) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit()));
|
||||
int bytesRead;
|
||||
try {
|
||||
int bytesRead = rawChannel.read(buffer);
|
||||
if (bytesRead < 0) {
|
||||
closeNow = true;
|
||||
bytesRead = 0;
|
||||
}
|
||||
return bytesRead;
|
||||
bytesRead = rawChannel.read(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
throw e;
|
||||
}
|
||||
if (bytesRead < 0) {
|
||||
closeNow = true;
|
||||
return 0;
|
||||
} else {
|
||||
ioBuffer.flip();
|
||||
buffer.put(ioBuffer);
|
||||
return bytesRead;
|
||||
}
|
||||
}
|
||||
|
||||
protected int readFromChannel(ByteBuffer[] buffers) throws IOException {
|
||||
protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
int bytesRead;
|
||||
try {
|
||||
int bytesRead = (int) rawChannel.read(buffers);
|
||||
if (bytesRead < 0) {
|
||||
closeNow = true;
|
||||
bytesRead = 0;
|
||||
}
|
||||
return bytesRead;
|
||||
bytesRead = rawChannel.read(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
throw e;
|
||||
}
|
||||
if (bytesRead < 0) {
|
||||
closeNow = true;
|
||||
return 0;
|
||||
} else {
|
||||
ioBuffer.flip();
|
||||
channelBuffer.ensureCapacity(channelBuffer.getIndex() + ioBuffer.remaining());
|
||||
ByteBuffer[] buffers = channelBuffer.sliceBuffersFrom(channelBuffer.getIndex());
|
||||
int j = 0;
|
||||
while (j < buffers.length && ioBuffer.remaining() > 0) {
|
||||
ByteBuffer buffer = buffers[j++];
|
||||
copyBytes(ioBuffer, buffer);
|
||||
}
|
||||
channelBuffer.incrementIndex(bytesRead);
|
||||
return bytesRead;
|
||||
}
|
||||
}
|
||||
|
||||
protected int flushToChannel(ByteBuffer buffer) throws IOException {
|
||||
int initialPosition = buffer.position();
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
copyBytes(buffer, ioBuffer);
|
||||
ioBuffer.flip();
|
||||
int bytesWritten;
|
||||
try {
|
||||
return rawChannel.write(buffer);
|
||||
bytesWritten = rawChannel.write(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
buffer.position(initialPosition);
|
||||
throw e;
|
||||
}
|
||||
buffer.position(initialPosition + bytesWritten);
|
||||
return bytesWritten;
|
||||
}
|
||||
|
||||
protected int flushToChannel(ByteBuffer[] buffers) throws IOException {
|
||||
try {
|
||||
return (int) rawChannel.write(buffers);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
throw e;
|
||||
protected int flushToChannel(FlushOperation flushOperation) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
|
||||
boolean continueFlush = flushOperation.isFullyFlushed() == false;
|
||||
int totalBytesFlushed = 0;
|
||||
while (continueFlush) {
|
||||
ioBuffer.clear();
|
||||
int j = 0;
|
||||
ByteBuffer[] buffers = flushOperation.getBuffersToWrite();
|
||||
while (j < buffers.length && ioBuffer.remaining() > 0) {
|
||||
ByteBuffer buffer = buffers[j++];
|
||||
copyBytes(buffer, ioBuffer);
|
||||
}
|
||||
ioBuffer.flip();
|
||||
int bytesFlushed;
|
||||
try {
|
||||
bytesFlushed = rawChannel.write(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
throw e;
|
||||
}
|
||||
flushOperation.incrementIndex(bytesFlushed);
|
||||
totalBytesFlushed += bytesFlushed;
|
||||
continueFlush = ioBuffer.hasRemaining() == false && flushOperation.isFullyFlushed() == false;
|
||||
}
|
||||
return totalBytesFlushed;
|
||||
}
|
||||
|
||||
private void copyBytes(ByteBuffer from, ByteBuffer to) {
|
||||
int nBytesToCopy = Math.min(to.remaining(), from.remaining());
|
||||
int initialLimit = from.limit();
|
||||
from.limit(from.position() + nBytesToCopy);
|
||||
to.put(from);
|
||||
from.limit(initialLimit);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ import java.util.function.BiConsumer;
|
|||
import java.util.function.Consumer;
|
||||
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyInt;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -64,14 +64,19 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer);
|
||||
|
||||
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
|
||||
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
|
||||
buffer.clear();
|
||||
return buffer;
|
||||
});
|
||||
}
|
||||
|
||||
public void testSuccessfulRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
|
||||
buffers[0].put(bytes);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
|
||||
|
@ -87,9 +92,9 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
public void testMultipleReadsConsumed() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength * 2);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
|
||||
buffers[0].put(bytes);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
|
||||
|
@ -105,9 +110,9 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
public void testPartialRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
|
||||
buffers[0].put(bytes);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
|
||||
|
@ -130,14 +135,14 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
|
||||
public void testReadThrowsIOException() throws IOException {
|
||||
IOException ioException = new IOException();
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(ioException);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(ioException);
|
||||
|
||||
IOException ex = expectThrows(IOException.class, () -> context.read());
|
||||
assertSame(ioException, ex);
|
||||
}
|
||||
|
||||
public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException());
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
assertFalse(context.selectorShouldClose());
|
||||
expectThrows(IOException.class, () -> context.read());
|
||||
|
@ -145,7 +150,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testReadLessThanZeroMeansReadyForClose() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
|
||||
|
||||
assertEquals(0, context.read());
|
||||
|
||||
|
@ -164,11 +169,13 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
assertTrue(context.readyForFlush());
|
||||
|
||||
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
|
||||
when(flushOperation.isFullyFlushed()).thenReturn(true);
|
||||
when(flushOperation.isFullyFlushed()).thenReturn(false, true);
|
||||
when(flushOperation.getListener()).thenReturn(listener);
|
||||
context.flushChannel();
|
||||
|
||||
verify(rawChannel).write(buffers, 0, buffers.length);
|
||||
ByteBuffer buffer = buffers[0].duplicate();
|
||||
buffer.flip();
|
||||
verify(rawChannel).write(eq(buffer));
|
||||
verify(selector).executeListener(listener, null);
|
||||
assertFalse(context.readyForFlush());
|
||||
}
|
||||
|
@ -180,7 +187,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
assertTrue(context.readyForFlush());
|
||||
|
||||
when(flushOperation.isFullyFlushed()).thenReturn(false);
|
||||
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
|
||||
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
|
||||
context.flushChannel();
|
||||
|
||||
verify(listener, times(0)).accept(null, null);
|
||||
|
@ -194,8 +201,8 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
|
||||
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
|
||||
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
|
||||
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
|
||||
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
|
||||
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
|
||||
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
|
||||
when(flushOperation1.getListener()).thenReturn(listener);
|
||||
when(flushOperation2.getListener()).thenReturn(listener2);
|
||||
|
||||
|
@ -204,7 +211,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
|
||||
assertTrue(context.readyForFlush());
|
||||
|
||||
when(flushOperation1.isFullyFlushed()).thenReturn(true);
|
||||
when(flushOperation1.isFullyFlushed()).thenReturn(false, true);
|
||||
when(flushOperation2.isFullyFlushed()).thenReturn(false);
|
||||
context.flushChannel();
|
||||
|
||||
|
@ -212,7 +219,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
verify(listener2, times(0)).accept(null, null);
|
||||
assertTrue(context.readyForFlush());
|
||||
|
||||
when(flushOperation2.isFullyFlushed()).thenReturn(true);
|
||||
when(flushOperation2.isFullyFlushed()).thenReturn(false, true);
|
||||
|
||||
context.flushChannel();
|
||||
|
||||
|
@ -231,7 +238,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
|
||||
IOException exception = new IOException();
|
||||
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
|
||||
when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
|
||||
when(flushOperation.getListener()).thenReturn(listener);
|
||||
expectThrows(IOException.class, () -> context.flushChannel());
|
||||
|
||||
|
@ -246,7 +253,7 @@ public class BytesChannelContextTests extends ESTestCase {
|
|||
|
||||
IOException exception = new IOException();
|
||||
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
|
||||
when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
|
||||
|
||||
assertFalse(context.selectorShouldClose());
|
||||
expectThrows(IOException.class, () -> context.flushChannel());
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.elasticsearch.nio;
|
|||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -54,6 +55,7 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
private BiConsumer<Void, Exception> listener;
|
||||
private NioSelector selector;
|
||||
private ReadWriteHandler readWriteHandler;
|
||||
private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Before
|
||||
|
@ -71,6 +73,10 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
|
||||
|
||||
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
|
||||
ioBuffer.clear();
|
||||
return ioBuffer;
|
||||
});
|
||||
}
|
||||
|
||||
public void testIOExceptionSetIfEncountered() throws IOException {
|
||||
|
@ -90,7 +96,6 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testSignalWhenPeerClosed() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
|
||||
assertFalse(context.closeNow());
|
||||
context.read();
|
||||
|
@ -289,6 +294,153 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testReadToBufferLimitsToPassedBuffer() throws IOException {
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
|
||||
|
||||
int bytesRead = context.readFromChannel(buffer);
|
||||
assertEquals(bytesRead, 10);
|
||||
assertEquals(0, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testReadToBufferHandlesIOException() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10)));
|
||||
assertTrue(context.closeNow());
|
||||
}
|
||||
|
||||
public void testReadToBufferHandlesEOF() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
|
||||
|
||||
context.readFromChannel(ByteBuffer.allocate(10));
|
||||
assertTrue(context.closeNow());
|
||||
}
|
||||
|
||||
public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
|
||||
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
int bytesRead = context.readFromChannel(channelBuffer);
|
||||
assertEquals(ioBuffer.capacity(), bytesRead);
|
||||
assertEquals(ioBuffer.capacity(), channelBuffer.getIndex());
|
||||
}
|
||||
|
||||
public void testReadToChannelBufferHandlesIOException() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
expectThrows(IOException.class, () -> context.readFromChannel(channelBuffer));
|
||||
assertTrue(context.closeNow());
|
||||
assertEquals(0, channelBuffer.getIndex());
|
||||
}
|
||||
|
||||
public void testReadToChannelBufferHandlesEOF() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
|
||||
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
context.readFromChannel(channelBuffer);
|
||||
assertTrue(context.closeNow());
|
||||
assertEquals(0, channelBuffer.getIndex());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesPartialFlush() throws IOException {
|
||||
int bytesToConsume = 3;
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
context.flushToChannel(buffer);
|
||||
assertEquals(10 - bytesToConsume, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesFullFlush() throws IOException {
|
||||
int bytesToConsume = 10;
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
context.flushToChannel(buffer);
|
||||
assertEquals(0, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesIOException() throws IOException {
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
expectThrows(IOException.class, () -> context.flushToChannel(buffer));
|
||||
assertTrue(context.closeNow());
|
||||
assertEquals(10, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesZeroFlush() throws IOException {
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0));
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
|
||||
FlushOperation flushOperation = new FlushOperation(buffers, listener);
|
||||
context.flushToChannel(flushOperation);
|
||||
assertEquals(2, flushOperation.getBuffersToWrite().length);
|
||||
assertEquals(0, flushOperation.getBuffersToWrite()[0].position());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesPartialFlush() throws IOException {
|
||||
AtomicBoolean first = new AtomicBoolean(true);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
if (first.compareAndSet(true, false)) {
|
||||
return consumeBufferAnswer(1024).answer(invocationOnMock);
|
||||
} else {
|
||||
return consumeBufferAnswer(3).answer(invocationOnMock);
|
||||
}
|
||||
});
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
|
||||
FlushOperation flushOperation = new FlushOperation(buffers, listener);
|
||||
context.flushToChannel(flushOperation);
|
||||
assertEquals(1, flushOperation.getBuffersToWrite().length);
|
||||
assertEquals(4, flushOperation.getBuffersToWrite()[0].position());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesFullFlush() throws IOException {
|
||||
AtomicBoolean first = new AtomicBoolean(true);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
if (first.compareAndSet(true, false)) {
|
||||
return consumeBufferAnswer(1024).answer(invocationOnMock);
|
||||
} else {
|
||||
return consumeBufferAnswer(1022).answer(invocationOnMock);
|
||||
}
|
||||
});
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
|
||||
FlushOperation flushOperation = new FlushOperation(buffers, listener);
|
||||
context.flushToChannel(flushOperation);
|
||||
assertTrue(flushOperation.isFullyFlushed());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesIOException() throws IOException {
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(10)};
|
||||
FlushOperation flushOperation = new FlushOperation(buffers, listener);
|
||||
expectThrows(IOException.class, () -> context.flushToChannel(flushOperation));
|
||||
assertTrue(context.closeNow());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesIOExceptionSecondTimeThroughLoop() throws IOException {
|
||||
AtomicBoolean first = new AtomicBoolean(true);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
if (first.compareAndSet(true, false)) {
|
||||
return consumeBufferAnswer(1024).answer(invocationOnMock);
|
||||
} else {
|
||||
throw new IOException();
|
||||
}
|
||||
});
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
|
||||
FlushOperation flushOperation = new FlushOperation(buffers, listener);
|
||||
expectThrows(IOException.class, () -> context.flushToChannel(flushOperation));
|
||||
assertTrue(context.closeNow());
|
||||
assertEquals(1, flushOperation.getBuffersToWrite().length);
|
||||
assertEquals(1, flushOperation.getBuffersToWrite()[0].position());
|
||||
}
|
||||
|
||||
private static class TestSocketChannelContext extends SocketChannelContext {
|
||||
|
||||
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
|
||||
|
@ -305,8 +457,8 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
@Override
|
||||
public int read() throws IOException {
|
||||
if (randomBoolean()) {
|
||||
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
|
||||
return readFromChannel(byteBuffers);
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
return readFromChannel(channelBuffer);
|
||||
} else {
|
||||
return readFromChannel(ByteBuffer.allocate(10));
|
||||
}
|
||||
|
@ -316,7 +468,7 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
public void flushChannel() throws IOException {
|
||||
if (randomBoolean()) {
|
||||
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
|
||||
flushToChannel(byteBuffers);
|
||||
flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
|
||||
} else {
|
||||
flushToChannel(ByteBuffer.allocate(10));
|
||||
}
|
||||
|
@ -345,4 +497,23 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
private Answer<Integer> completelyFillBufferAnswer() {
|
||||
return invocationOnMock -> {
|
||||
ByteBuffer b = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
int bytesRead = b.remaining();
|
||||
while (b.hasRemaining()) {
|
||||
b.put((byte) 1);
|
||||
}
|
||||
return bytesRead;
|
||||
};
|
||||
}
|
||||
|
||||
private Answer<Object> consumeBufferAnswer(int bytesToConsume) {
|
||||
return invocationOnMock -> {
|
||||
ByteBuffer b = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
b.position(b.position() + bytesToConsume);
|
||||
return bytesToConsume;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,8 +10,8 @@ import org.elasticsearch.common.util.BigArrays;
|
|||
import org.elasticsearch.nio.BytesWriteHandler;
|
||||
import org.elasticsearch.nio.FlushReadyWrite;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.WriteOperation;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.Before;
|
||||
|
@ -68,12 +68,17 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
|
||||
when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer);
|
||||
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
|
||||
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
|
||||
buffer.clear();
|
||||
return buffer;
|
||||
});
|
||||
}
|
||||
|
||||
public void testSuccessfulRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
|
||||
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
|
||||
|
@ -88,7 +93,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
public void testMultipleReadsConsumed() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength * 2);
|
||||
|
||||
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
|
||||
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
|
||||
|
@ -103,7 +108,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
public void testPartialRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
|
||||
|
||||
|
@ -212,7 +217,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
context.flushChannel();
|
||||
|
||||
verify(sslDriver, times(2)).nonApplicationWrite();
|
||||
verify(rawChannel, times(2)).write(sslDriver.getNetworkWriteBuffer());
|
||||
verify(rawChannel, times(2)).write(same(selector.getIoBuffer()));
|
||||
}
|
||||
|
||||
public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception {
|
||||
|
@ -223,7 +228,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
context.flushChannel();
|
||||
|
||||
verify(sslDriver, times(1)).nonApplicationWrite();
|
||||
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
|
||||
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
|
||||
}
|
||||
|
||||
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
|
||||
|
@ -240,7 +245,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
context.flushChannel();
|
||||
|
||||
verify(flushOperation).incrementIndex(10);
|
||||
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
|
||||
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
|
||||
verify(selector).executeListener(listener, null);
|
||||
assertFalse(context.readyForFlush());
|
||||
}
|
||||
|
@ -258,8 +263,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
when(flushOperation.isFullyFlushed()).thenReturn(false, false);
|
||||
context.flushChannel();
|
||||
|
||||
verify(flushOperation).incrementIndex(5);
|
||||
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
|
||||
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
|
||||
verify(selector, times(0)).executeListener(listener, null);
|
||||
assertTrue(context.readyForFlush());
|
||||
}
|
||||
|
@ -287,7 +291,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
context.flushChannel();
|
||||
|
||||
verify(flushOperation1, times(2)).incrementIndex(5);
|
||||
verify(rawChannel, times(3)).write(sslDriver.getNetworkWriteBuffer());
|
||||
verify(rawChannel, times(3)).write(same(selector.getIoBuffer()));
|
||||
verify(selector).executeListener(listener, null);
|
||||
verify(selector, times(0)).executeListener(listener2, null);
|
||||
assertTrue(context.readyForFlush());
|
||||
|
@ -304,7 +308,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
when(sslDriver.hasFlushPending()).thenReturn(false, false);
|
||||
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
|
||||
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
|
||||
when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception);
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
|
||||
when(flushOperation.isFullyFlushed()).thenReturn(false);
|
||||
expectThrows(IOException.class, () -> context.flushChannel());
|
||||
|
||||
|
@ -317,7 +321,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
when(sslDriver.hasFlushPending()).thenReturn(true);
|
||||
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
|
||||
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
|
||||
when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException());
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
assertFalse(context.selectorShouldClose());
|
||||
expectThrows(IOException.class, () -> context.flushChannel());
|
||||
|
|
Loading…
Reference in New Issue