Revert "Remove dedicated SSL network write buffer (#41283)"

This reverts commit f65a86c258.
This commit is contained in:
Tim Brooks 2019-04-25 18:39:25 -06:00
parent c4cb0507b4
commit 1f8ff052a1
No known key found for this signature in database
GPG Key ID: C2AA3BB91A889E77
22 changed files with 330 additions and 481 deletions

View File

@ -25,8 +25,6 @@ import java.util.function.BiConsumer;
public class FlushOperation {
private static final ByteBuffer[] EMPTY_ARRAY = new ByteBuffer[0];
private final BiConsumer<Void, Exception> listener;
private final ByteBuffer[] buffers;
private final int[] offsets;
@ -63,38 +61,19 @@ public class FlushOperation {
}
public ByteBuffer[] getBuffersToWrite() {
return getBuffersToWrite(length);
}
public ByteBuffer[] getBuffersToWrite(int maxBytes) {
final int index = Arrays.binarySearch(offsets, internalIndex);
final int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
final int finalIndex = Arrays.binarySearch(offsets, Math.min(internalIndex + maxBytes, length));
final int finalOffsetIndex = finalIndex < 0 ? (-(finalIndex + 1)) - 1 : finalIndex;
int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
int nBuffers = (finalOffsetIndex - offsetIndex) + 1;
ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
int firstBufferPosition = internalIndex - offsets[offsetIndex];
ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
firstBuffer.position(firstBufferPosition);
if (nBuffers == 1 && firstBuffer.remaining() == 0) {
return EMPTY_ARRAY;
}
ByteBuffer[] postIndexBuffers = new ByteBuffer[nBuffers];
firstBuffer.position(internalIndex - offsets[offsetIndex]);
postIndexBuffers[0] = firstBuffer;
int finalOffset = offsetIndex + nBuffers;
int nBytes = firstBuffer.remaining();
int j = 1;
for (int i = (offsetIndex + 1); i < finalOffset; ++i) {
ByteBuffer buffer = buffers[i].duplicate();
nBytes += buffer.remaining();
postIndexBuffers[j++] = buffer;
for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
postIndexBuffers[j++] = buffers[i].duplicate();
}
int excessBytes = Math.max(0, nBytes - maxBytes);
ByteBuffer lastBuffer = postIndexBuffers[postIndexBuffers.length - 1];
lastBuffer.limit(lastBuffer.limit() - excessBytes);
return postIndexBuffers;
}
}

View File

@ -27,7 +27,7 @@ public class FlushReadyWrite extends FlushOperation implements WriteOperation {
private final SocketChannelContext channelContext;
private final ByteBuffer[] buffers;
public FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer<Void, Exception> listener) {
FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer<Void, Exception> listener) {
super(buffers, listener);
this.channelContext = channelContext;
this.buffers = buffers;

View File

@ -19,6 +19,7 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.nio.ByteBuffer;
@ -139,11 +140,11 @@ public final class InboundChannelBuffer implements AutoCloseable {
ByteBuffer[] buffers = new ByteBuffer[pageCount];
Iterator<Page> pageIterator = pages.iterator();
ByteBuffer firstBuffer = pageIterator.next().byteBuffer().duplicate();
ByteBuffer firstBuffer = pageIterator.next().byteBuffer.duplicate();
firstBuffer.position(firstBuffer.position() + offset);
buffers[0] = firstBuffer;
for (int i = 1; i < buffers.length; i++) {
buffers[i] = pageIterator.next().byteBuffer().duplicate();
buffers[i] = pageIterator.next().byteBuffer.duplicate();
}
if (finalLimit != 0) {
buffers[buffers.length - 1].limit(finalLimit);
@ -179,14 +180,14 @@ public final class InboundChannelBuffer implements AutoCloseable {
Page[] pages = new Page[pageCount];
Iterator<Page> pageIterator = this.pages.iterator();
Page firstPage = pageIterator.next().duplicate();
ByteBuffer firstBuffer = firstPage.byteBuffer();
ByteBuffer firstBuffer = firstPage.byteBuffer;
firstBuffer.position(firstBuffer.position() + offset);
pages[0] = firstPage;
for (int i = 1; i < pages.length; i++) {
pages[i] = pageIterator.next().duplicate();
}
if (finalLimit != 0) {
pages[pages.length - 1].byteBuffer().limit(finalLimit);
pages[pages.length - 1].byteBuffer.limit(finalLimit);
}
return pages;
@ -216,9 +217,9 @@ public final class InboundChannelBuffer implements AutoCloseable {
ByteBuffer[] buffers = new ByteBuffer[pages.size() - pageIndex];
Iterator<Page> pageIterator = pages.descendingIterator();
for (int i = buffers.length - 1; i > 0; --i) {
buffers[i] = pageIterator.next().byteBuffer().duplicate();
buffers[i] = pageIterator.next().byteBuffer.duplicate();
}
ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer().duplicate();
ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer.duplicate();
firstPostIndexBuffer.position(firstPostIndexBuffer.position() + indexInPage);
buffers[0] = firstPostIndexBuffer;
@ -267,4 +268,53 @@ public final class InboundChannelBuffer implements AutoCloseable {
private int indexInPage(long index) {
return (int) (index & PAGE_MASK);
}
public static class Page implements AutoCloseable {
private final ByteBuffer byteBuffer;
// This is reference counted as some implementations want to retain the byte pages by calling
// sliceAndRetainPagesTo. With reference counting we can increment the reference count, return the
// pages, and safely close them when this channel buffer is done with them. The reference count
// would be 1 at that point, meaning that the pages will remain until the implementation closes
// theirs.
private final RefCountedCloseable refCountedCloseable;
public Page(ByteBuffer byteBuffer, Runnable closeable) {
this(byteBuffer, new RefCountedCloseable(closeable));
}
private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
private Page duplicate() {
refCountedCloseable.incRef();
return new Page(byteBuffer.duplicate(), refCountedCloseable);
}
public ByteBuffer getByteBuffer() {
return byteBuffer;
}
@Override
public void close() {
refCountedCloseable.decRef();
}
private static class RefCountedCloseable extends AbstractRefCounted {
private final Runnable closeable;
private RefCountedCloseable(Runnable closeable) {
super("byte array page");
this.closeable = closeable;
}
@Override
protected void closeInternal() {
closeable.run();
}
}
}
}

View File

@ -1,89 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
import java.io.Closeable;
import java.nio.ByteBuffer;
public class Page implements Closeable {
private final ByteBuffer byteBuffer;
// This is reference counted as some implementations want to retain the byte pages by calling
// duplicate. With reference counting we can increment the reference count, return a new page,
// and safely close the pages independently. The closeable will not be called until each page is
// released.
private final RefCountedCloseable refCountedCloseable;
public Page(ByteBuffer byteBuffer) {
this(byteBuffer, () -> {});
}
public Page(ByteBuffer byteBuffer, Runnable closeable) {
this(byteBuffer, new RefCountedCloseable(closeable));
}
private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
/**
* Duplicates this page and increments the reference count. The new page must be closed independently
* of the original page.
*
* @return the new page
*/
public Page duplicate() {
refCountedCloseable.incRef();
return new Page(byteBuffer.duplicate(), refCountedCloseable);
}
/**
* Returns the {@link ByteBuffer} for this page. Modifications to the limits, positions, etc of the
* buffer will also mutate this page. Call {@link ByteBuffer#duplicate()} to avoid mutating the page.
*
* @return the byte buffer
*/
public ByteBuffer byteBuffer() {
return byteBuffer;
}
@Override
public void close() {
refCountedCloseable.decRef();
}
private static class RefCountedCloseable extends AbstractRefCounted {
private final Runnable closeable;
private RefCountedCloseable(Runnable closeable) {
super("byte array page");
this.closeable = closeable;
}
@Override
protected void closeInternal() {
closeable.run();
}
}
}

View File

@ -325,7 +325,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
ioBuffer.clear();
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
int j = 0;
ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT);
ByteBuffer[] buffers = flushOperation.getBuffersToWrite();
while (j < buffers.length && ioBuffer.remaining() > 0) {
ByteBuffer buffer = buffers[j++];
copyBytes(buffer, ioBuffer);

View File

@ -31,7 +31,6 @@ import java.util.function.BiConsumer;
import java.util.function.Consumer;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@ -169,7 +168,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.isFullyFlushed()).thenReturn(false, true);
when(flushOperation.getListener()).thenReturn(listener);
context.flushChannel();
@ -188,7 +187,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
when(flushOperation.isFullyFlushed()).thenReturn(false);
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
context.flushChannel();
verify(listener, times(0)).accept(null, null);
@ -202,8 +201,8 @@ public class BytesChannelContextTests extends ESTestCase {
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
when(flushOperation1.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation2.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation1.getListener()).thenReturn(listener);
when(flushOperation2.getListener()).thenReturn(listener2);
@ -238,7 +237,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
when(flushOperation.getListener()).thenReturn(listener);
expectThrows(IOException.class, () -> context.flushChannel());
@ -253,7 +252,7 @@ public class BytesChannelContextTests extends ESTestCase {
context.queueWriteOperation(flushOperation);
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
assertFalse(context.selectorShouldClose());

View File

@ -65,45 +65,29 @@ public class FlushOperationTests extends ESTestCase {
ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite();
assertEquals(3, byteBuffers.length);
assertEquals(5, byteBuffers[0].remaining());
ByteBuffer[] byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(2, byteBuffersWithLimit.length);
assertEquals(5, byteBuffersWithLimit[0].remaining());
assertEquals(5, byteBuffersWithLimit[1].remaining());
writeOp.incrementIndex(5);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length);
assertEquals(15, byteBuffers[0].remaining());
assertEquals(3, byteBuffers[1].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(10, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(2);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length);
assertEquals(13, byteBuffers[0].remaining());
assertEquals(3, byteBuffers[1].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(10, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(15);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(1, byteBuffers.length);
assertEquals(1, byteBuffers[0].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(1, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(1);
assertTrue(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(0, byteBuffers.length);
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(0, byteBuffersWithLimit.length);
assertEquals(1, byteBuffers.length);
assertEquals(0, byteBuffers[0].remaining());
}
}

View File

@ -30,8 +30,8 @@ import java.util.function.Supplier;
public class InboundChannelBufferTests extends ESTestCase {
private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES;
private final Supplier<Page> defaultPageSupplier = () ->
new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
private final Supplier<InboundChannelBuffer.Page> defaultPageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
});
public void testNewBufferNoPages() {
@ -126,10 +126,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testReleaseClosesPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
Supplier<InboundChannelBuffer.Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -153,10 +153,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testClose() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
Supplier<InboundChannelBuffer.Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -178,10 +178,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testCloseRetainedPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
Supplier<InboundChannelBuffer.Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -192,7 +192,7 @@ public class InboundChannelBufferTests extends ESTestCase {
assertFalse(closedRef.get());
}
Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
InboundChannelBuffer.Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
pages[1].close();

View File

@ -285,7 +285,7 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true);
Runnable closer = mock(Runnable.class);
Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);

View File

@ -29,7 +29,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.WriteOperation;
import java.nio.ByteBuffer;
@ -97,7 +97,7 @@ class NettyAdaptor {
return byteBuf.readerIndex() - initialReaderIndex;
}
public int read(Page[] pages) {
public int read(InboundChannelBuffer.Page[] pages) {
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages);
int readableBytes = byteBuf.readableBytes();
nettyChannel.writeInbound(byteBuf);

View File

@ -43,7 +43,6 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.RestUtils;
@ -206,9 +205,9 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
java.util.function.Supplier<Page> pageSupplier = () -> {
java.util.function.Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig);

View File

@ -24,7 +24,7 @@ import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.InboundChannelBuffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@ -39,7 +39,7 @@ public class PagedByteBuf extends UnpooledHeapByteBuf {
this.releasable = releasable;
}
static ByteBuf byteBufFromPages(Page[] pages) {
static ByteBuf byteBufFromPages(InboundChannelBuffer.Page[] pages) {
int componentCount = pages.length;
if (componentCount == 0) {
return Unpooled.EMPTY_BUFFER;
@ -48,15 +48,15 @@ public class PagedByteBuf extends UnpooledHeapByteBuf {
} else {
int maxComponents = Math.max(16, componentCount);
final List<ByteBuf> components = new ArrayList<>(componentCount);
for (Page page : pages) {
for (InboundChannelBuffer.Page page : pages) {
components.add(byteBufFromPage(page));
}
return new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, maxComponents, components);
}
}
private static ByteBuf byteBufFromPage(Page page) {
ByteBuffer buffer = page.byteBuffer();
private static ByteBuf byteBufFromPage(InboundChannelBuffer.Page page) {
ByteBuffer buffer = page.getByteBuffer();
assert buffer.isDirect() == false && buffer.hasArray() : "Must be a heap buffer with an array";
int offset = buffer.arrayOffset() + buffer.position();
PagedByteBuf newByteBuf = new PagedByteBuf(buffer.array(), page::close);

View File

@ -36,7 +36,6 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
@ -158,9 +157,9 @@ public class NioTransport extends TcpTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);

View File

@ -20,7 +20,7 @@
package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.test.ESTestCase;
import java.nio.ByteBuffer;
@ -32,12 +32,12 @@ public class PagedByteBufTests extends ESTestCase {
public void testReleasingPage() {
AtomicInteger integer = new AtomicInteger(0);
int pageCount = randomInt(10) + 1;
ArrayList<Page> pages = new ArrayList<>();
ArrayList<InboundChannelBuffer.Page> pages = new ArrayList<>();
for (int i = 0; i < pageCount; ++i) {
pages.add(new Page(ByteBuffer.allocate(10), integer::incrementAndGet));
pages.add(new InboundChannelBuffer.Page(ByteBuffer.allocate(10), integer::incrementAndGet));
}
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new Page[0]));
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new InboundChannelBuffer.Page[0]));
assertEquals(0, integer.get());
byteBuf.retain();
@ -62,9 +62,9 @@ public class PagedByteBufTests extends ESTestCase {
bytes2[i - 10] = (byte) i;
}
Page[] pages = new Page[2];
pages[0] = new Page(ByteBuffer.wrap(bytes1), () -> {});
pages[1] = new Page(ByteBuffer.wrap(bytes2), () -> {});
InboundChannelBuffer.Page[] pages = new InboundChannelBuffer.Page[2];
pages[0] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes1), () -> {});
pages[1] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes2), () -> {});
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages);
assertEquals(20, byteBuf.readableBytes());
@ -73,13 +73,13 @@ public class PagedByteBufTests extends ESTestCase {
assertEquals((byte) i, byteBuf.getByte(i));
}
Page[] pages2 = new Page[2];
InboundChannelBuffer.Page[] pages2 = new InboundChannelBuffer.Page[2];
ByteBuffer firstBuffer = ByteBuffer.wrap(bytes1);
firstBuffer.position(2);
ByteBuffer secondBuffer = ByteBuffer.wrap(bytes2);
secondBuffer.limit(8);
pages2[0] = new Page(firstBuffer, () -> {});
pages2[1] = new Page(secondBuffer, () -> {});
pages2[0] = new InboundChannelBuffer.Page(firstBuffer, () -> {});
pages2[1] = new InboundChannelBuffer.Page(secondBuffer, () -> {});
ByteBuf byteBuf2 = PagedByteBuf.byteBufFromPages(pages2);
assertEquals(16, byteBuf2.readableBytes());

View File

@ -41,7 +41,6 @@ import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectionProfile;
@ -192,9 +191,9 @@ public class MockNioTransport extends TcpTransport {
@Override
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),

View File

@ -10,7 +10,6 @@ import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.NioSelector;
@ -18,8 +17,6 @@ import org.elasticsearch.nio.WriteOperation;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -37,8 +34,6 @@ public final class SSLChannelContext extends SocketChannelContext {
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
private final SSLDriver sslDriver;
private final SSLOutboundBuffer outboundBuffer;
private FlushOperation encryptedFlush;
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
@ -51,8 +46,6 @@ public final class SSLChannelContext extends SocketChannelContext {
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
this.sslDriver = sslDriver;
// TODO: When the bytes are actually recycled, we need to test that they are released on context close
this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
}
@Override
@ -79,32 +72,34 @@ public final class SSLChannelContext extends SocketChannelContext {
return;
}
// If there is currently data in the outbound write buffer, flush the buffer.
if (pendingChannelFlush()) {
if (sslDriver.hasFlushPending()) {
// If the data is not completely flushed, exit. We cannot produce new write data until the
// existing data has been fully flushed.
flushEncryptedOperation();
if (pendingChannelFlush()) {
flushToChannel(sslDriver.getNetworkWriteBuffer());
if (sslDriver.hasFlushPending()) {
return;
}
}
// If the driver is ready for application writes, we can attempt to proceed with any queued writes.
if (sslDriver.readyForApplicationWrites()) {
FlushOperation unencryptedFlush;
while (pendingChannelFlush() == false && (unencryptedFlush = getPendingFlush()) != null) {
if (unencryptedFlush.isFullyFlushed()) {
FlushOperation currentFlush;
while (sslDriver.hasFlushPending() == false && (currentFlush = getPendingFlush()) != null) {
// If the current operation has been fully consumed (encrypted) we now know that it has been
// sent (as we only get to this point if the write buffer has been fully flushed).
if (currentFlush.isFullyFlushed()) {
currentFlushOperationComplete();
} else {
try {
// Attempt to encrypt application write data. The encrypted data ends up in the
// outbound write buffer.
sslDriver.write(unencryptedFlush, outboundBuffer);
if (outboundBuffer.hasEncryptedBytesToFlush() == false) {
int bytesEncrypted = sslDriver.applicationWrite(currentFlush.getBuffersToWrite());
if (bytesEncrypted == 0) {
break;
}
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
currentFlush.incrementIndex(bytesEncrypted);
// Flush the write buffer to the channel
flushEncryptedOperation();
flushToChannel(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
currentFlushOperationFailed(e);
throw e;
@ -114,38 +109,23 @@ public final class SSLChannelContext extends SocketChannelContext {
} else {
// We are not ready for application writes, check if the driver has non-application writes. We
// only want to continue producing new writes if the outbound write buffer is fully flushed.
while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) {
sslDriver.nonApplicationWrite(outboundBuffer);
while (sslDriver.hasFlushPending() == false && sslDriver.needsNonApplicationWrite()) {
sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer.
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
flushEncryptedOperation();
if (sslDriver.hasFlushPending()) {
flushToChannel(sslDriver.getNetworkWriteBuffer());
}
}
}
}
private void flushEncryptedOperation() throws IOException {
try {
flushToChannel(encryptedFlush);
if (encryptedFlush.isFullyFlushed()) {
getSelector().executeListener(encryptedFlush.getListener(), null);
encryptedFlush = null;
}
} catch (IOException e) {
getSelector().executeFailedListener(encryptedFlush.getListener(), e);
encryptedFlush = null;
throw e;
}
}
@Override
public boolean readyForFlush() {
getSelector().assertOnSelectorThread();
if (sslDriver.readyForApplicationWrites()) {
return pendingChannelFlush() || super.readyForFlush();
return sslDriver.hasFlushPending() || super.readyForFlush();
} else {
return pendingChannelFlush() || sslDriver.needsNonApplicationWrite();
return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite();
}
}
@ -169,7 +149,7 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public boolean selectorShouldClose() {
return closeNow() || (sslDriver.isClosed() && pendingChannelFlush() == false);
return closeNow() || sslDriver.isClosed();
}
@Override
@ -190,10 +170,7 @@ public final class SSLChannelContext extends SocketChannelContext {
getSelector().assertOnSelectorThread();
if (channel.isOpen()) {
closeTimeoutCanceller.run();
if (encryptedFlush != null) {
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
}
IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close);
IOUtils.close(super::closeFromSelector, sslDriver::close);
}
}
@ -207,14 +184,9 @@ public final class SSLChannelContext extends SocketChannelContext {
getSelector().queueChannelClose(channel);
}
private boolean pendingChannelFlush() {
return encryptedFlush != null;
}
private static class CloseNotifyOperation implements WriteOperation {
private static final BiConsumer<Void, Exception> LISTENER = (v, t) -> {
};
private static final BiConsumer<Void, Exception> LISTENER = (v, t) -> {};
private static final Object WRITE_OBJECT = new Object();
private final SocketChannelContext channelContext;

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.utils.ExceptionsHelper;
@ -30,17 +29,19 @@ import java.util.ArrayList;
* the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
* or handshake process.
*
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
* called to determine if this driver needs to produce more data to advance the handshake or close process.
* If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the
* data produced then flushed to the channel) until no further non-application writes are needed.
* Producing writes for a channel is more complicated. If there is existing data in the outbound write buffer
* as indicated by {@link #hasFlushPending()}, that data must be written to the channel before more outbound
* data can be produced. If no flushes are pending, {@link #needsNonApplicationWrite()} can be called to
* determine if this driver needs to produce more data to advance the handshake or close process. If that
* method returns true, {@link #nonApplicationWrite()} should be called (and the data produced then flushed
* to the channel) until no further non-application writes are needed.
*
* If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine
* if the driver is ready to consume application data. (Note: It is possible that
* {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the
* driver is waiting on non-application data from the peer.) If the driver indicates it is ready for
* application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will
* encrypt flush operation application data and place it in the outbound buffer for flushing to a channel.
* application writes, {@link #applicationWrite(ByteBuffer[])} can be called. This method will encrypt
* application data and place it in the write buffer for flushing to a channel.
*
* If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the
* driver will start producing non-application writes related to notifying the peer connection that this
@ -49,23 +50,23 @@ import java.util.ArrayList;
*/
public class SSLDriver implements AutoCloseable {
private static final ByteBuffer[] EMPTY_BUFFERS = {ByteBuffer.allocate(0)};
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0];
private final SSLEngine engine;
private final boolean isClientMode;
// This should only be accessed by the network thread associated with this channel, so nothing needs to
// be volatile.
private Mode currentMode = new HandshakeMode();
private ByteBuffer networkWriteBuffer;
private ByteBuffer networkReadBuffer;
private int packetSize;
public SSLDriver(SSLEngine engine, boolean isClientMode) {
this.engine = engine;
this.isClientMode = isClientMode;
SSLSession session = engine.getSession();
packetSize = session.getPacketBufferSize();
this.networkReadBuffer = ByteBuffer.allocate(packetSize);
this.networkReadBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer.position(this.networkWriteBuffer.limit());
}
public void init() throws SSLException {
@ -99,10 +100,18 @@ public class SSLDriver implements AutoCloseable {
return engine;
}
public boolean hasFlushPending() {
return networkWriteBuffer.hasRemaining();
}
public boolean isHandshaking() {
return currentMode.isHandshake();
}
public ByteBuffer getNetworkWriteBuffer() {
return networkWriteBuffer;
}
public ByteBuffer getNetworkReadBuffer() {
return networkReadBuffer;
}
@ -125,14 +134,15 @@ public class SSLDriver implements AutoCloseable {
return currentMode.needsNonApplicationWrite();
}
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
return currentMode.write(applicationBytes, outboundBuffer);
public int applicationWrite(ByteBuffer[] buffers) throws SSLException {
assert readyForApplicationWrites() : "Should not be called if driver is not ready for application writes";
return currentMode.write(buffers);
}
public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException {
public void nonApplicationWrite() throws SSLException {
assert currentMode.isApplication() == false : "Should not be called if driver is in application mode";
if (currentMode.isApplication() == false) {
currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer);
currentMode.write(EMPTY_BUFFER_ARRAY);
} else {
throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName());
}
@ -195,36 +205,45 @@ public class SSLDriver implements AutoCloseable {
}
}
private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer) throws SSLException {
return wrap(outboundBuffer, EMPTY_FLUSH_OPERATION);
}
private SSLEngineResult wrap(ByteBuffer[] buffers) throws SSLException {
assert hasFlushPending() == false : "Should never called with pending writes";
private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer, FlushOperation applicationBytes) throws SSLException {
ByteBuffer[] buffers = applicationBytes.getBuffersToWrite(engine.getSession().getApplicationBufferSize());
networkWriteBuffer.clear();
while (true) {
SSLEngineResult result;
ByteBuffer networkBuffer = outboundBuffer.nextWriteBuffer(packetSize);
try {
result = engine.wrap(buffers, networkBuffer);
if (buffers.length == 1) {
result = engine.wrap(buffers[0], networkWriteBuffer);
} else {
result = engine.wrap(buffers, networkWriteBuffer);
}
} catch (SSLException e) {
outboundBuffer.incrementEncryptedBytes(0);
networkWriteBuffer.position(networkWriteBuffer.limit());
throw e;
}
outboundBuffer.incrementEncryptedBytes(result.bytesProduced());
applicationBytes.incrementIndex(result.bytesConsumed());
switch (result.getStatus()) {
case OK:
networkWriteBuffer.flip();
return result;
case BUFFER_UNDERFLOW:
throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP");
case BUFFER_OVERFLOW:
packetSize = engine.getSession().getPacketBufferSize();
// There is not enough space in the network buffer for an entire SSL packet. We will
// allocate a buffer with the correct packet size the next time through the loop.
// There is not enough space in the network buffer for an entire SSL packet. Expand the
// buffer if it's smaller than the current session packet size. Otherwise return and wait
// for existing data to be flushed.
int currentCapacity = networkWriteBuffer.capacity();
ensureNetworkWriteBufferSize();
if (currentCapacity == networkWriteBuffer.capacity()) {
return result;
}
break;
case CLOSED:
assert result.bytesProduced() > 0 : "WRAP during close processing should produce close message.";
if (result.bytesProduced() > 0) {
networkWriteBuffer.flip();
} else {
assert false : "WRAP during close processing should produce close message.";
}
return result;
default:
throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus());
@ -246,12 +265,23 @@ public class SSLDriver implements AutoCloseable {
}
}
private void ensureNetworkWriteBufferSize() {
networkWriteBuffer = ensureNetBufferSize(networkWriteBuffer);
}
private void ensureNetworkReadBufferSize() {
packetSize = engine.getSession().getPacketBufferSize();
if (networkReadBuffer.capacity() < packetSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(packetSize);
networkReadBuffer.flip();
newBuffer.put(networkReadBuffer);
networkReadBuffer = ensureNetBufferSize(networkReadBuffer);
}
private ByteBuffer ensureNetBufferSize(ByteBuffer current) {
int networkPacketSize = engine.getSession().getPacketBufferSize();
if (current.capacity() < networkPacketSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize);
current.flip();
newBuffer.put(current);
return newBuffer;
} else {
return current;
}
}
@ -276,7 +306,7 @@ public class SSLDriver implements AutoCloseable {
void read(InboundChannelBuffer buffer) throws SSLException;
int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException;
int write(ByteBuffer[] buffers) throws SSLException;
boolean needsNonApplicationWrite();
@ -299,7 +329,7 @@ public class SSLDriver implements AutoCloseable {
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
try {
handshake(null);
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
@ -307,7 +337,7 @@ public class SSLDriver implements AutoCloseable {
}
}
private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException {
private void handshake() throws SSLException {
boolean continueHandshaking = true;
while (continueHandshaking) {
switch (handshakeStatus) {
@ -316,13 +346,11 @@ public class SSLDriver implements AutoCloseable {
continueHandshaking = false;
break;
case NEED_WRAP:
if (outboundBuffer != null) {
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
// If we need NEED_TASK we should run the tasks immediately
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) {
continueHandshaking = false;
}
} else {
if (hasFlushPending() == false) {
handshakeStatus = wrap(EMPTY_BUFFER_ARRAY).getHandshakeStatus();
}
// If we need NEED_TASK we should run the tasks immediately
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) {
continueHandshaking = false;
}
break;
@ -351,7 +379,7 @@ public class SSLDriver implements AutoCloseable {
try {
SSLEngineResult result = unwrap(buffer);
handshakeStatus = result.getHandshakeStatus();
handshake(null);
handshake();
// If we are done handshaking we should exit the handshake read
continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake();
} catch (SSLException e) {
@ -362,9 +390,9 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
public int write(ByteBuffer[] buffers) throws SSLException {
try {
handshake(outboundBuffer);
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
@ -417,7 +445,8 @@ public class SSLDriver implements AutoCloseable {
String message = "Expected to be in handshaking/closed mode. Instead in application mode.";
throw new AssertionError(message);
}
} else {
} else if (hasFlushPending() == false) {
// We only acknowledge that we are done handshaking if there are no bytes that need to be written
if (currentMode.isHandshake()) {
currentMode = new ApplicationMode();
} else {
@ -444,17 +473,10 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
boolean continueWrap = true;
int totalBytesProduced = 0;
while (continueWrap && applicationBytes.isFullyFlushed() == false) {
SSLEngineResult result = wrap(outboundBuffer, applicationBytes);
int bytesProduced = result.bytesProduced();
totalBytesProduced += bytesProduced;
boolean renegotiationRequested = maybeRenegotiation(result.getHandshakeStatus());
continueWrap = bytesProduced > 0 && renegotiationRequested == false;
}
return totalBytesProduced;
public int write(ByteBuffer[] buffers) throws SSLException {
SSLEngineResult result = wrap(buffers);
maybeRenegotiation(result.getHandshakeStatus());
return result.bytesConsumed();
}
private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException {
@ -538,19 +560,18 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
int bytesProduced = 0;
if (engine.isOutboundDone() == false) {
bytesProduced += wrap(outboundBuffer).bytesProduced();
if (engine.isOutboundDone()) {
needToSendClose = false;
// Close inbound if it is still open and we have decided not to wait for response.
if (needToReceiveClose == false && engine.isInboundDone() == false) {
closeInboundAndSwallowPeerDidNotCloseException();
}
public int write(ByteBuffer[] buffers) throws SSLException {
if (hasFlushPending() == false && engine.isOutboundDone()) {
needToSendClose = false;
// Close inbound if it is still open and we have decided not to wait for response.
if (needToReceiveClose == false && engine.isInboundDone() == false) {
closeInboundAndSwallowPeerDidNotCloseException();
}
} else {
wrap(EMPTY_BUFFER_ARRAY);
assert hasFlushPending() : "Should have produced close message";
}
return bytesProduced;
return 0;
}
@Override

View File

@ -1,68 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.Page;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.function.IntFunction;
public class SSLOutboundBuffer implements AutoCloseable {
private final ArrayDeque<Page> pages;
private final IntFunction<Page> pageSupplier;
private Page currentPage;
SSLOutboundBuffer(IntFunction<Page> pageSupplier) {
this.pages = new ArrayDeque<>();
this.pageSupplier = pageSupplier;
}
void incrementEncryptedBytes(int encryptedBytesProduced) {
if (encryptedBytesProduced != 0) {
currentPage.byteBuffer().limit(encryptedBytesProduced);
pages.addLast(currentPage);
}
currentPage = null;
}
ByteBuffer nextWriteBuffer(int networkBufferSize) {
if (currentPage != null) {
// If there is an existing page, close it as it wasn't large enough to accommodate the SSLEngine.
currentPage.close();
}
Page newPage = pageSupplier.apply(networkBufferSize);
currentPage = newPage;
return newPage.byteBuffer().duplicate();
}
FlushOperation buildNetworkFlushOperation() {
int pageCount = pages.size();
ByteBuffer[] byteBuffers = new ByteBuffer[pageCount];
Page[] pagesToClose = new Page[pageCount];
for (int i = 0; i < pageCount; ++i) {
Page page = pages.removeFirst();
pagesToClose[i] = page;
byteBuffers[i] = page.byteBuffer();
}
return new FlushOperation(byteBuffers, (r, e) -> IOUtils.closeWhileHandlingException(pagesToClose));
}
boolean hasEncryptedBytesToFlush() {
return pages.isEmpty() == false;
}
@Override
public void close() {
IOUtils.closeWhileHandlingException(pages);
}
}

View File

@ -22,7 +22,6 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -93,9 +92,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
Supplier<Page> pageSupplier = () -> {
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig);

View File

@ -21,7 +21,6 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -156,9 +155,9 @@ public class SecurityNioTransport extends NioTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.FlushReadyWrite;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
@ -29,7 +28,6 @@ import java.util.function.Consumer;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@ -51,6 +49,7 @@ public class SSLChannelContextTests extends ESTestCase {
private Consumer exceptionHandler;
private SSLDriver sslDriver;
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14);
private int messageLength;
@Before
@ -74,6 +73,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(selector.isOnCurrentThread()).thenReturn(true);
when(selector.getTaskScheduler()).thenReturn(nioTimer);
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer);
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
buffer.clear();
@ -85,7 +85,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
@ -100,7 +100,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength * 2);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
@ -115,7 +115,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(0);
@ -173,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase {
public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
context.queueWriteOperation(mock(FlushReadyWrite.class));
@ -180,25 +181,25 @@ public class SSLChannelContextTests extends ESTestCase {
assertFalse(context.readyForFlush());
}
public void testPendingEncryptedFlushMeansWriteInterested() throws Exception {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class));
public void testPendingFlushMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(randomBoolean());
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
// Call will put bytes in buffer to flush
context.flushChannel();
assertTrue(context.readyForFlush());
}
public void testNeedsNonAppWritesMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
assertTrue(context.readyForFlush());
}
public void testNoNonAppWriteInterestInAppMode() {
public void testNotWritesInterestInAppMode() {
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.hasFlushPending()).thenReturn(false);
assertFalse(context.readyForFlush());
@ -206,68 +207,66 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testFirstFlushMustFinishForWriteToContinue() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(true, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class));
// First call will put bytes in buffer to flush
context.flushChannel();
assertTrue(context.readyForFlush());
// Second call will will not continue generating non-app bytes because they still need to be flushed
context.flushChannel();
assertTrue(context.readyForFlush());
verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class));
verify(sslDriver, times(0)).nonApplicationWrite();
}
public void testNonAppWrites() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, false, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1);
context.flushChannel();
verify(sslDriver, times(2)).nonApplicationWrite(any(SSLOutboundBuffer.class));
verify(sslDriver, times(2)).nonApplicationWrite();
verify(rawChannel, times(2)).write(same(selector.getIoBuffer()));
}
public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception {
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, true, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0);
context.flushChannel();
verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class));
verify(sslDriver, times(1)).nonApplicationWrite();
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
context.queueWriteOperation(flushOperation);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10);
when(sslDriver.applicationWrite(buffers)).thenReturn(10);
when(flushOperation.isFullyFlushed()).thenReturn(false,true);
context.flushChannel();
verify(flushOperation).incrementIndex(10);
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
verify(selector).executeListener(listener, null);
assertFalse(context.readyForFlush());
}
public void testPartialFlush() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
context.queueWriteOperation(flushOperation);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(flushOperation.isFullyFlushed()).thenReturn(false, false);
context.flushChannel();
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
@ -280,16 +279,24 @@ public class SSLChannelContextTests extends ESTestCase {
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)};
ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation1 = new FlushReadyWrite(context, buffers1, listener);
FlushReadyWrite flushOperation2 = new FlushReadyWrite(context, buffers2, listener2);
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
when(flushOperation1.getBuffersToWrite()).thenReturn(buffers1);
when(flushOperation2.getBuffersToWrite()).thenReturn(buffers2);
when(flushOperation1.getListener()).thenReturn(listener);
when(flushOperation2.getListener()).thenReturn(listener2);
context.queueWriteOperation(flushOperation1);
context.queueWriteOperation(flushOperation2);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class), any(SSLOutboundBuffer.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2);
when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5);
when(sslDriver.applicationWrite(buffers2)).thenReturn(3);
when(flushOperation1.isFullyFlushed()).thenReturn(false, false, true);
when(flushOperation2.isFullyFlushed()).thenReturn(false);
context.flushChannel();
verify(flushOperation1, times(2)).incrementIndex(5);
verify(rawChannel, times(3)).write(same(selector.getIoBuffer()));
verify(selector).executeListener(listener, null);
verify(selector, times(0)).executeListener(listener2, null);
@ -297,27 +304,29 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
context.queueWriteOperation(flushOperation);
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class));
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
when(flushOperation.isFullyFlushed()).thenReturn(false);
expectThrows(IOException.class, () -> context.flushChannel());
verify(flushOperation).incrementIndex(5);
verify(selector).executeFailedListener(listener, exception);
assertFalse(context.readyForFlush());
}
public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class));
context.flushChannel();
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
@ -404,27 +413,7 @@ public class SSLChannelContextTests extends ESTestCase {
}
}
private Answer<Integer> getWriteAnswer(int bytesToEncrypt, boolean isApp) {
return invocationOnMock -> {
SSLOutboundBuffer outboundBuffer;
if (isApp) {
outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[1];
} else {
outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[0];
}
ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1);
for (int i = 0; i < bytesToEncrypt; ++i) {
byteBuffer.put((byte) i);
}
outboundBuffer.incrementEncryptedBytes(bytesToEncrypt);
if (isApp) {
((FlushOperation) invocationOnMock.getArguments()[0]).incrementIndex(bytesToEncrypt);
}
return bytesToEncrypt;
};
}
private Answer getReadAnswerForBytes(byte[] bytes) {
private Answer getAnswerForBytes(byte[] bytes) {
return invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
buffer.ensureCapacity(buffer.getIndex() + bytes.length);

View File

@ -6,9 +6,7 @@
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.bootstrap.JavaVersion;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ssl.CertParsingUtils;
import org.elasticsearch.xpack.core.ssl.PemUtils;
@ -30,7 +28,8 @@ import java.util.function.Supplier;
public class SSLDriverTests extends ESTestCase {
private final Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {});
private final Supplier<InboundChannelBuffer.Page> pageSupplier =
() -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {});
private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
@ -142,6 +141,10 @@ public class SSLDriverTests extends ESTestCase {
boolean expectedMessage = oldExpected.equals(sslException.getMessage()) || jdk11Expected.equals(sslException.getMessage());
assertTrue("Unexpected exception message: " + sslException.getMessage(), expectedMessage);
// In JDK11 we need an non-application write
if (serverDriver.needsNonApplicationWrite()) {
serverDriver.nonApplicationWrite();
}
// Prior to JDK11 we still need to send a close alert
if (serverDriver.isClosed() == false) {
failedCloseAlert(serverDriver, clientDriver, Arrays.asList("Received fatal alert: protocol_version",
@ -163,7 +166,10 @@ public class SSLDriverTests extends ESTestCase {
SSLDriver serverDriver = getDriver(serverEngine, false);
expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver));
// In JDK11 we need an non-application write
if (serverDriver.needsNonApplicationWrite()) {
serverDriver.nonApplicationWrite();
}
// Prior to JDK11 we still need to send a close alert
if (serverDriver.isClosed() == false) {
List<String> messages = Arrays.asList("Received fatal alert: handshake_failure",
@ -186,6 +192,8 @@ public class SSLDriverTests extends ESTestCase {
sendHandshakeMessages(clientDriver, serverDriver);
sendHandshakeMessages(serverDriver, clientDriver);
sendData(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
@ -219,6 +227,8 @@ public class SSLDriverTests extends ESTestCase {
sendHandshakeMessages(clientDriver, serverDriver);
sendHandshakeMessages(serverDriver, clientDriver);
sendData(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
@ -296,12 +306,12 @@ public class SSLDriverTests extends ESTestCase {
}
private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
} else {
sendDriver.nonApplicationWrite(outboundBuffer);
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
if (sendDriver.hasFlushPending() == false) {
sendDriver.nonApplicationWrite();
}
if (sendDriver.hasFlushPending()) {
sendData(sendDriver, receiveDriver, true);
}
}
}
@ -316,7 +326,7 @@ public class SSLDriverTests extends ESTestCase {
serverDriver.init();
}
assertTrue(clientDriver.needsNonApplicationWrite());
assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
@ -340,51 +350,58 @@ public class SSLDriverTests extends ESTestCase {
}
private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException {
assertTrue(sendDriver.needsNonApplicationWrite());
assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending());
SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
if (sendDriver.hasFlushPending() == false) {
sendDriver.nonApplicationWrite();
}
if (sendDriver.isHandshaking()) {
assertTrue(sendDriver.hasFlushPending());
sendData(sendDriver, receiveDriver);
assertFalse(sendDriver.hasFlushPending());
receiveDriver.read(genericBuffer);
} else {
sendDriver.nonApplicationWrite(outboundBuffer);
}
}
if (receiveDriver.isHandshaking()) {
assertTrue(receiveDriver.needsNonApplicationWrite());
assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.hasFlushPending());
}
}
private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
assertFalse(sendDriver.needsNonApplicationWrite());
int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {});
int bytesEncrypted = 0;
while (bytesToEncrypt > bytesEncrypted) {
bytesEncrypted += sendDriver.write(flushOperation, outboundBuffer);
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
bytesEncrypted += sendDriver.applicationWrite(message);
sendData(sendDriver, receiveDriver);
}
}
private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) {
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver) {
sendData(sendDriver, receiveDriver, randomBoolean());
}
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver, boolean partial) {
ByteBuffer writeBuffer = sendDriver.getNetworkWriteBuffer();
ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite();
int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer";
assert writeBuffers.length > 0 : "No write buffers";
for (ByteBuffer writeBuffer : writeBuffers) {
int written = writeBuffer.remaining();
if (partial) {
int initialLimit = writeBuffer.limit();
int bytesToWrite = writeBuffer.remaining() / (randomInt(2) + 2);
writeBuffer.limit(writeBuffer.position() + bytesToWrite);
readBuffer.put(writeBuffer);
flushOperation.incrementIndex(written);
}
writeBuffer.limit(initialLimit);
assertTrue(sendDriver.hasFlushPending());
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
assertTrue(flushOperation.isFullyFlushed());
} else {
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
}
}
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {