Remove dedicated SSL network write buffer (#41654)

This is related to #27260. Currently for the SSLDriver we allocate a
dedicated network write buffer and encrypt the data into that buffer one
buffer at a time. This requires constantly switching between encrypting
and flushing. This commit adds a dedicated outbound buffer for SSL
operations that will internally allocate new packet sized buffers as
they are need (for writing encrypted data). This allows us to totally
encrypt an operation before writing it to the network. Eventually it can
be hooked up to buffer recycling.

This commit also backports the following commit:

Handle WRAP ops during SSL read

It is possible that a WRAP operation can occur while decrypting
handshake data in TLS 1.3. The SSLDriver does not currently handle this
well as it does not have access to the outbound buffer during read call.
This commit moves the buffer into the Driver to fix this issue. Data
wrapped during a read call will be queued for writing after the read
call is complete.
This commit is contained in:
Tim Brooks 2019-04-29 17:59:13 -06:00 committed by GitHub
parent 10ab838106
commit df3ef66294
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 485 additions and 331 deletions

View File

@ -25,6 +25,8 @@ import java.util.function.BiConsumer;
public class FlushOperation {
private static final ByteBuffer[] EMPTY_ARRAY = new ByteBuffer[0];
private final BiConsumer<Void, Exception> listener;
private final ByteBuffer[] buffers;
private final int[] offsets;
@ -61,19 +63,38 @@ public class FlushOperation {
}
public ByteBuffer[] getBuffersToWrite() {
return getBuffersToWrite(length);
}
public ByteBuffer[] getBuffersToWrite(int maxBytes) {
final int index = Arrays.binarySearch(offsets, internalIndex);
int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
final int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
final int finalIndex = Arrays.binarySearch(offsets, Math.min(internalIndex + maxBytes, length));
final int finalOffsetIndex = finalIndex < 0 ? (-(finalIndex + 1)) - 1 : finalIndex;
ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
int nBuffers = (finalOffsetIndex - offsetIndex) + 1;
int firstBufferPosition = internalIndex - offsets[offsetIndex];
ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
firstBuffer.position(internalIndex - offsets[offsetIndex]);
postIndexBuffers[0] = firstBuffer;
int j = 1;
for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
postIndexBuffers[j++] = buffers[i].duplicate();
firstBuffer.position(firstBufferPosition);
if (nBuffers == 1 && firstBuffer.remaining() == 0) {
return EMPTY_ARRAY;
}
ByteBuffer[] postIndexBuffers = new ByteBuffer[nBuffers];
postIndexBuffers[0] = firstBuffer;
int finalOffset = offsetIndex + nBuffers;
int nBytes = firstBuffer.remaining();
int j = 1;
for (int i = (offsetIndex + 1); i < finalOffset; ++i) {
ByteBuffer buffer = buffers[i].duplicate();
nBytes += buffer.remaining();
postIndexBuffers[j++] = buffer;
}
int excessBytes = Math.max(0, nBytes - maxBytes);
ByteBuffer lastBuffer = postIndexBuffers[postIndexBuffers.length - 1];
lastBuffer.limit(lastBuffer.limit() - excessBytes);
return postIndexBuffers;
}
}

View File

@ -27,7 +27,7 @@ public class FlushReadyWrite extends FlushOperation implements WriteOperation {
private final SocketChannelContext channelContext;
private final ByteBuffer[] buffers;
FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer<Void, Exception> listener) {
public FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer<Void, Exception> listener) {
super(buffers, listener);
this.channelContext = channelContext;
this.buffers = buffers;

View File

@ -19,7 +19,6 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.nio.ByteBuffer;
@ -140,11 +139,11 @@ public final class InboundChannelBuffer implements AutoCloseable {
ByteBuffer[] buffers = new ByteBuffer[pageCount];
Iterator<Page> pageIterator = pages.iterator();
ByteBuffer firstBuffer = pageIterator.next().byteBuffer.duplicate();
ByteBuffer firstBuffer = pageIterator.next().byteBuffer().duplicate();
firstBuffer.position(firstBuffer.position() + offset);
buffers[0] = firstBuffer;
for (int i = 1; i < buffers.length; i++) {
buffers[i] = pageIterator.next().byteBuffer.duplicate();
buffers[i] = pageIterator.next().byteBuffer().duplicate();
}
if (finalLimit != 0) {
buffers[buffers.length - 1].limit(finalLimit);
@ -180,14 +179,14 @@ public final class InboundChannelBuffer implements AutoCloseable {
Page[] pages = new Page[pageCount];
Iterator<Page> pageIterator = this.pages.iterator();
Page firstPage = pageIterator.next().duplicate();
ByteBuffer firstBuffer = firstPage.byteBuffer;
ByteBuffer firstBuffer = firstPage.byteBuffer();
firstBuffer.position(firstBuffer.position() + offset);
pages[0] = firstPage;
for (int i = 1; i < pages.length; i++) {
pages[i] = pageIterator.next().duplicate();
}
if (finalLimit != 0) {
pages[pages.length - 1].byteBuffer.limit(finalLimit);
pages[pages.length - 1].byteBuffer().limit(finalLimit);
}
return pages;
@ -217,9 +216,9 @@ public final class InboundChannelBuffer implements AutoCloseable {
ByteBuffer[] buffers = new ByteBuffer[pages.size() - pageIndex];
Iterator<Page> pageIterator = pages.descendingIterator();
for (int i = buffers.length - 1; i > 0; --i) {
buffers[i] = pageIterator.next().byteBuffer.duplicate();
buffers[i] = pageIterator.next().byteBuffer().duplicate();
}
ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer.duplicate();
ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer().duplicate();
firstPostIndexBuffer.position(firstPostIndexBuffer.position() + indexInPage);
buffers[0] = firstPostIndexBuffer;
@ -268,53 +267,4 @@ public final class InboundChannelBuffer implements AutoCloseable {
private int indexInPage(long index) {
return (int) (index & PAGE_MASK);
}
public static class Page implements AutoCloseable {
private final ByteBuffer byteBuffer;
// This is reference counted as some implementations want to retain the byte pages by calling
// sliceAndRetainPagesTo. With reference counting we can increment the reference count, return the
// pages, and safely close them when this channel buffer is done with them. The reference count
// would be 1 at that point, meaning that the pages will remain until the implementation closes
// theirs.
private final RefCountedCloseable refCountedCloseable;
public Page(ByteBuffer byteBuffer, Runnable closeable) {
this(byteBuffer, new RefCountedCloseable(closeable));
}
private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
private Page duplicate() {
refCountedCloseable.incRef();
return new Page(byteBuffer.duplicate(), refCountedCloseable);
}
public ByteBuffer getByteBuffer() {
return byteBuffer;
}
@Override
public void close() {
refCountedCloseable.decRef();
}
private static class RefCountedCloseable extends AbstractRefCounted {
private final Runnable closeable;
private RefCountedCloseable(Runnable closeable) {
super("byte array page");
this.closeable = closeable;
}
@Override
protected void closeInternal() {
closeable.run();
}
}
}
}

View File

@ -0,0 +1,89 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
import java.io.Closeable;
import java.nio.ByteBuffer;
public class Page implements Closeable {
private final ByteBuffer byteBuffer;
// This is reference counted as some implementations want to retain the byte pages by calling
// duplicate. With reference counting we can increment the reference count, return a new page,
// and safely close the pages independently. The closeable will not be called until each page is
// released.
private final RefCountedCloseable refCountedCloseable;
public Page(ByteBuffer byteBuffer) {
this(byteBuffer, () -> {});
}
public Page(ByteBuffer byteBuffer, Runnable closeable) {
this(byteBuffer, new RefCountedCloseable(closeable));
}
private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
/**
* Duplicates this page and increments the reference count. The new page must be closed independently
* of the original page.
*
* @return the new page
*/
public Page duplicate() {
refCountedCloseable.incRef();
return new Page(byteBuffer.duplicate(), refCountedCloseable);
}
/**
* Returns the {@link ByteBuffer} for this page. Modifications to the limits, positions, etc of the
* buffer will also mutate this page. Call {@link ByteBuffer#duplicate()} to avoid mutating the page.
*
* @return the byte buffer
*/
public ByteBuffer byteBuffer() {
return byteBuffer;
}
@Override
public void close() {
refCountedCloseable.decRef();
}
private static class RefCountedCloseable extends AbstractRefCounted {
private final Runnable closeable;
private RefCountedCloseable(Runnable closeable) {
super("byte array page");
this.closeable = closeable;
}
@Override
protected void closeInternal() {
closeable.run();
}
}
}

View File

@ -325,7 +325,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
ioBuffer.clear();
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
int j = 0;
ByteBuffer[] buffers = flushOperation.getBuffersToWrite();
ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT);
while (j < buffers.length && ioBuffer.remaining() > 0) {
ByteBuffer buffer = buffers[j++];
copyBytes(buffer, ioBuffer);

View File

@ -31,6 +31,7 @@ import java.util.function.BiConsumer;
import java.util.function.Consumer;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@ -168,7 +169,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(flushOperation.isFullyFlushed()).thenReturn(false, true);
when(flushOperation.getListener()).thenReturn(listener);
context.flushChannel();
@ -187,7 +188,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
when(flushOperation.isFullyFlushed()).thenReturn(false);
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
context.flushChannel();
verify(listener, times(0)).accept(null, null);
@ -201,8 +202,8 @@ public class BytesChannelContextTests extends ESTestCase {
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation1.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation2.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
when(flushOperation1.getListener()).thenReturn(listener);
when(flushOperation2.getListener()).thenReturn(listener2);
@ -237,7 +238,7 @@ public class BytesChannelContextTests extends ESTestCase {
assertTrue(context.readyForFlush());
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
when(flushOperation.getListener()).thenReturn(listener);
expectThrows(IOException.class, () -> context.flushChannel());
@ -252,7 +253,7 @@ public class BytesChannelContextTests extends ESTestCase {
context.queueWriteOperation(flushOperation);
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers);
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
assertFalse(context.selectorShouldClose());

View File

@ -65,29 +65,45 @@ public class FlushOperationTests extends ESTestCase {
ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite();
assertEquals(3, byteBuffers.length);
assertEquals(5, byteBuffers[0].remaining());
ByteBuffer[] byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(2, byteBuffersWithLimit.length);
assertEquals(5, byteBuffersWithLimit[0].remaining());
assertEquals(5, byteBuffersWithLimit[1].remaining());
writeOp.incrementIndex(5);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length);
assertEquals(15, byteBuffers[0].remaining());
assertEquals(3, byteBuffers[1].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(10, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(2);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(2, byteBuffers.length);
assertEquals(13, byteBuffers[0].remaining());
assertEquals(3, byteBuffers[1].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(10, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(15);
assertFalse(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(1, byteBuffers.length);
assertEquals(1, byteBuffers[0].remaining());
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(1, byteBuffersWithLimit.length);
assertEquals(1, byteBuffersWithLimit[0].remaining());
writeOp.incrementIndex(1);
assertTrue(writeOp.isFullyFlushed());
byteBuffers = writeOp.getBuffersToWrite();
assertEquals(1, byteBuffers.length);
assertEquals(0, byteBuffers[0].remaining());
assertEquals(0, byteBuffers.length);
byteBuffersWithLimit = writeOp.getBuffersToWrite(10);
assertEquals(0, byteBuffersWithLimit.length);
}
}

View File

@ -30,8 +30,8 @@ import java.util.function.Supplier;
public class InboundChannelBufferTests extends ESTestCase {
private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES;
private final Supplier<InboundChannelBuffer.Page> defaultPageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
private final Supplier<Page> defaultPageSupplier = () ->
new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
});
public void testNewBufferNoPages() {
@ -126,10 +126,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testReleaseClosesPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<InboundChannelBuffer.Page> supplier = () -> {
Supplier<Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -153,10 +153,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testClose() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<InboundChannelBuffer.Page> supplier = () -> {
Supplier<Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -178,10 +178,10 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testCloseRetainedPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<InboundChannelBuffer.Page> supplier = () -> {
Supplier<Page> supplier = () -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
@ -192,7 +192,7 @@ public class InboundChannelBufferTests extends ESTestCase {
assertFalse(closedRef.get());
}
InboundChannelBuffer.Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
pages[1].close();

View File

@ -285,7 +285,7 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true);
Runnable closer = mock(Runnable.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);

View File

@ -29,7 +29,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.WriteOperation;
import java.nio.ByteBuffer;
@ -97,7 +97,7 @@ class NettyAdaptor {
return byteBuf.readerIndex() - initialReaderIndex;
}
public int read(InboundChannelBuffer.Page[] pages) {
public int read(Page[] pages) {
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages);
int readableBytes = byteBuf.readableBytes();
nettyChannel.writeInbound(byteBuf);

View File

@ -43,6 +43,7 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.RestUtils;
@ -205,9 +206,9 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
java.util.function.Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
java.util.function.Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig);

View File

@ -24,7 +24,7 @@ import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@ -39,7 +39,7 @@ public class PagedByteBuf extends UnpooledHeapByteBuf {
this.releasable = releasable;
}
static ByteBuf byteBufFromPages(InboundChannelBuffer.Page[] pages) {
static ByteBuf byteBufFromPages(Page[] pages) {
int componentCount = pages.length;
if (componentCount == 0) {
return Unpooled.EMPTY_BUFFER;
@ -48,15 +48,15 @@ public class PagedByteBuf extends UnpooledHeapByteBuf {
} else {
int maxComponents = Math.max(16, componentCount);
final List<ByteBuf> components = new ArrayList<>(componentCount);
for (InboundChannelBuffer.Page page : pages) {
for (Page page : pages) {
components.add(byteBufFromPage(page));
}
return new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, maxComponents, components);
}
}
private static ByteBuf byteBufFromPage(InboundChannelBuffer.Page page) {
ByteBuffer buffer = page.getByteBuffer();
private static ByteBuf byteBufFromPage(Page page) {
ByteBuffer buffer = page.byteBuffer();
assert buffer.isDirect() == false && buffer.hasArray() : "Must be a heap buffer with an array";
int offset = buffer.arrayOffset() + buffer.position();
PagedByteBuf newByteBuf = new PagedByteBuf(buffer.array(), page::close);

View File

@ -36,6 +36,7 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
@ -157,9 +158,9 @@ public class NioTransport extends TcpTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);

View File

@ -20,7 +20,7 @@
package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.test.ESTestCase;
import java.nio.ByteBuffer;
@ -32,12 +32,12 @@ public class PagedByteBufTests extends ESTestCase {
public void testReleasingPage() {
AtomicInteger integer = new AtomicInteger(0);
int pageCount = randomInt(10) + 1;
ArrayList<InboundChannelBuffer.Page> pages = new ArrayList<>();
ArrayList<Page> pages = new ArrayList<>();
for (int i = 0; i < pageCount; ++i) {
pages.add(new InboundChannelBuffer.Page(ByteBuffer.allocate(10), integer::incrementAndGet));
pages.add(new Page(ByteBuffer.allocate(10), integer::incrementAndGet));
}
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new InboundChannelBuffer.Page[0]));
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new Page[0]));
assertEquals(0, integer.get());
byteBuf.retain();
@ -62,9 +62,9 @@ public class PagedByteBufTests extends ESTestCase {
bytes2[i - 10] = (byte) i;
}
InboundChannelBuffer.Page[] pages = new InboundChannelBuffer.Page[2];
pages[0] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes1), () -> {});
pages[1] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes2), () -> {});
Page[] pages = new Page[2];
pages[0] = new Page(ByteBuffer.wrap(bytes1), () -> {});
pages[1] = new Page(ByteBuffer.wrap(bytes2), () -> {});
ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages);
assertEquals(20, byteBuf.readableBytes());
@ -73,13 +73,13 @@ public class PagedByteBufTests extends ESTestCase {
assertEquals((byte) i, byteBuf.getByte(i));
}
InboundChannelBuffer.Page[] pages2 = new InboundChannelBuffer.Page[2];
Page[] pages2 = new Page[2];
ByteBuffer firstBuffer = ByteBuffer.wrap(bytes1);
firstBuffer.position(2);
ByteBuffer secondBuffer = ByteBuffer.wrap(bytes2);
secondBuffer.limit(8);
pages2[0] = new InboundChannelBuffer.Page(firstBuffer, () -> {});
pages2[1] = new InboundChannelBuffer.Page(secondBuffer, () -> {});
pages2[0] = new Page(firstBuffer, () -> {});
pages2[1] = new Page(secondBuffer, () -> {});
ByteBuf byteBuf2 = PagedByteBuf.byteBufFromPages(pages2);
assertEquals(16, byteBuf2.readableBytes());

View File

@ -41,6 +41,7 @@ import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectionProfile;
@ -191,9 +192,9 @@ public class MockNioTransport extends TcpTransport {
@Override
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),

View File

@ -9,14 +9,16 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.WriteOperation;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -34,6 +36,7 @@ public final class SSLChannelContext extends SocketChannelContext {
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
private final SSLDriver sslDriver;
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
@ -52,6 +55,10 @@ public final class SSLChannelContext extends SocketChannelContext {
public void register() throws IOException {
super.register();
sslDriver.init();
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
}
}
@Override
@ -72,34 +79,33 @@ public final class SSLChannelContext extends SocketChannelContext {
return;
}
// If there is currently data in the outbound write buffer, flush the buffer.
if (sslDriver.hasFlushPending()) {
if (pendingChannelFlush()) {
// If the data is not completely flushed, exit. We cannot produce new write data until the
// existing data has been fully flushed.
flushToChannel(sslDriver.getNetworkWriteBuffer());
if (sslDriver.hasFlushPending()) {
flushEncryptedOperation();
if (pendingChannelFlush()) {
return;
}
}
// If the driver is ready for application writes, we can attempt to proceed with any queued writes.
if (sslDriver.readyForApplicationWrites()) {
FlushOperation currentFlush;
while (sslDriver.hasFlushPending() == false && (currentFlush = getPendingFlush()) != null) {
// If the current operation has been fully consumed (encrypted) we now know that it has been
// sent (as we only get to this point if the write buffer has been fully flushed).
if (currentFlush.isFullyFlushed()) {
FlushOperation unencryptedFlush;
while (pendingChannelFlush() == false && (unencryptedFlush = getPendingFlush()) != null) {
if (unencryptedFlush.isFullyFlushed()) {
currentFlushOperationComplete();
} else {
try {
// Attempt to encrypt application write data. The encrypted data ends up in the
// outbound write buffer.
int bytesEncrypted = sslDriver.applicationWrite(currentFlush.getBuffersToWrite());
if (bytesEncrypted == 0) {
sslDriver.write(unencryptedFlush);
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush() == false) {
break;
}
currentFlush.incrementIndex(bytesEncrypted);
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
// Flush the write buffer to the channel
flushToChannel(sslDriver.getNetworkWriteBuffer());
flushEncryptedOperation();
} catch (IOException e) {
currentFlushOperationFailed(e);
throw e;
@ -109,23 +115,39 @@ public final class SSLChannelContext extends SocketChannelContext {
} else {
// We are not ready for application writes, check if the driver has non-application writes. We
// only want to continue producing new writes if the outbound write buffer is fully flushed.
while (sslDriver.hasFlushPending() == false && sslDriver.needsNonApplicationWrite()) {
while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) {
sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer.
if (sslDriver.hasFlushPending()) {
flushToChannel(sslDriver.getNetworkWriteBuffer());
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlushes.addFirst(outboundBuffer.buildNetworkFlushOperation());
flushEncryptedOperation();
}
}
}
}
private void flushEncryptedOperation() throws IOException {
try {
FlushOperation encryptedFlush = encryptedFlushes.getFirst();
flushToChannel(encryptedFlush);
if (encryptedFlush.isFullyFlushed()) {
getSelector().executeListener(encryptedFlush.getListener(), null);
encryptedFlushes.removeFirst();
}
} catch (IOException e) {
getSelector().executeFailedListener(encryptedFlushes.removeFirst().getListener(), e);
throw e;
}
}
@Override
public boolean readyForFlush() {
getSelector().assertOnSelectorThread();
if (sslDriver.readyForApplicationWrites()) {
return sslDriver.hasFlushPending() || super.readyForFlush();
return pendingChannelFlush() || super.readyForFlush();
} else {
return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite();
return pendingChannelFlush() || sslDriver.needsNonApplicationWrite();
}
}
@ -143,13 +165,18 @@ public final class SSLChannelContext extends SocketChannelContext {
sslDriver.read(channelBuffer);
handleReadBytes();
// It is possible that a read call produced non-application bytes to flush
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
}
return bytesRead;
}
@Override
public boolean selectorShouldClose() {
return closeNow() || sslDriver.isClosed();
return closeNow() || (sslDriver.isClosed() && pendingChannelFlush() == false);
}
@Override
@ -170,6 +197,10 @@ public final class SSLChannelContext extends SocketChannelContext {
getSelector().assertOnSelectorThread();
if (channel.isOpen()) {
closeTimeoutCanceller.run();
for (FlushOperation encryptedFlush : encryptedFlushes) {
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
}
encryptedFlushes.clear();
IOUtils.close(super::closeFromSelector, sslDriver::close);
}
}
@ -184,9 +215,14 @@ public final class SSLChannelContext extends SocketChannelContext {
getSelector().queueChannelClose(channel);
}
private boolean pendingChannelFlush() {
return encryptedFlushes.isEmpty() == false;
}
private static class CloseNotifyOperation implements WriteOperation {
private static final BiConsumer<Void, Exception> LISTENER = (v, t) -> {};
private static final BiConsumer<Void, Exception> LISTENER = (v, t) -> {
};
private static final Object WRITE_OBJECT = new Object();
private final SocketChannelContext channelContext;

View File

@ -5,7 +5,9 @@
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import javax.net.ssl.SSLEngine;
@ -29,19 +31,17 @@ import java.util.ArrayList;
* the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
* or handshake process.
*
* Producing writes for a channel is more complicated. If there is existing data in the outbound write buffer
* as indicated by {@link #hasFlushPending()}, that data must be written to the channel before more outbound
* data can be produced. If no flushes are pending, {@link #needsNonApplicationWrite()} can be called to
* determine if this driver needs to produce more data to advance the handshake or close process. If that
* method returns true, {@link #nonApplicationWrite()} should be called (and the data produced then flushed
* to the channel) until no further non-application writes are needed.
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
* called to determine if this driver needs to produce more data to advance the handshake or close process.
* If that method returns true, {@link #nonApplicationWrite()} should be called (and the
* data produced then flushed to the channel) until no further non-application writes are needed.
*
* If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine
* if the driver is ready to consume application data. (Note: It is possible that
* {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the
* driver is waiting on non-application data from the peer.) If the driver indicates it is ready for
* application writes, {@link #applicationWrite(ByteBuffer[])} can be called. This method will encrypt
* application data and place it in the write buffer for flushing to a channel.
* application writes, {@link #write(FlushOperation)} can be called. This method will
* encrypt flush operation application data and place it in the outbound buffer for flushing to a channel.
*
* If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the
* driver will start producing non-application writes related to notifying the peer connection that this
@ -50,23 +50,25 @@ import java.util.ArrayList;
*/
public class SSLDriver implements AutoCloseable {
private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0];
private static final ByteBuffer[] EMPTY_BUFFERS = {ByteBuffer.allocate(0)};
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
private final SSLEngine engine;
// TODO: When the bytes are actually recycled, we need to test that they are released on driver close
private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
private final boolean isClientMode;
// This should only be accessed by the network thread associated with this channel, so nothing needs to
// be volatile.
private Mode currentMode = new HandshakeMode();
private ByteBuffer networkWriteBuffer;
private ByteBuffer networkReadBuffer;
private int packetSize;
public SSLDriver(SSLEngine engine, boolean isClientMode) {
this.engine = engine;
this.isClientMode = isClientMode;
SSLSession session = engine.getSession();
this.networkReadBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer.position(this.networkWriteBuffer.limit());
packetSize = session.getPacketBufferSize();
this.networkReadBuffer = ByteBuffer.allocate(packetSize);
}
public void init() throws SSLException {
@ -100,22 +102,18 @@ public class SSLDriver implements AutoCloseable {
return engine;
}
public boolean hasFlushPending() {
return networkWriteBuffer.hasRemaining();
}
public boolean isHandshaking() {
return currentMode.isHandshake();
}
public ByteBuffer getNetworkWriteBuffer() {
return networkWriteBuffer;
}
public ByteBuffer getNetworkReadBuffer() {
return networkReadBuffer;
}
public SSLOutboundBuffer getOutboundBuffer() {
return outboundBuffer;
}
public void read(InboundChannelBuffer buffer) throws SSLException {
Mode modePriorToRead;
do {
@ -134,15 +132,14 @@ public class SSLDriver implements AutoCloseable {
return currentMode.needsNonApplicationWrite();
}
public int applicationWrite(ByteBuffer[] buffers) throws SSLException {
assert readyForApplicationWrites() : "Should not be called if driver is not ready for application writes";
return currentMode.write(buffers);
public int write(FlushOperation applicationBytes) throws SSLException {
return currentMode.write(applicationBytes);
}
public void nonApplicationWrite() throws SSLException {
assert currentMode.isApplication() == false : "Should not be called if driver is in application mode";
if (currentMode.isApplication() == false) {
currentMode.write(EMPTY_BUFFER_ARRAY);
currentMode.write(EMPTY_FLUSH_OPERATION);
} else {
throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName());
}
@ -158,6 +155,7 @@ public class SSLDriver implements AutoCloseable {
@Override
public void close() throws SSLException {
outboundBuffer.close();
ArrayList<SSLException> closingExceptions = new ArrayList<>(2);
closingInternal();
CloseMode closeMode = (CloseMode) this.currentMode;
@ -205,45 +203,36 @@ public class SSLDriver implements AutoCloseable {
}
}
private SSLEngineResult wrap(ByteBuffer[] buffers) throws SSLException {
assert hasFlushPending() == false : "Should never called with pending writes";
private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer) throws SSLException {
return wrap(outboundBuffer, EMPTY_FLUSH_OPERATION);
}
networkWriteBuffer.clear();
private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer, FlushOperation applicationBytes) throws SSLException {
ByteBuffer[] buffers = applicationBytes.getBuffersToWrite(engine.getSession().getApplicationBufferSize());
while (true) {
SSLEngineResult result;
ByteBuffer networkBuffer = outboundBuffer.nextWriteBuffer(packetSize);
try {
if (buffers.length == 1) {
result = engine.wrap(buffers[0], networkWriteBuffer);
} else {
result = engine.wrap(buffers, networkWriteBuffer);
}
result = engine.wrap(buffers, networkBuffer);
} catch (SSLException e) {
networkWriteBuffer.position(networkWriteBuffer.limit());
outboundBuffer.incrementEncryptedBytes(0);
throw e;
}
outboundBuffer.incrementEncryptedBytes(result.bytesProduced());
applicationBytes.incrementIndex(result.bytesConsumed());
switch (result.getStatus()) {
case OK:
networkWriteBuffer.flip();
return result;
case BUFFER_UNDERFLOW:
throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP");
case BUFFER_OVERFLOW:
// There is not enough space in the network buffer for an entire SSL packet. Expand the
// buffer if it's smaller than the current session packet size. Otherwise return and wait
// for existing data to be flushed.
int currentCapacity = networkWriteBuffer.capacity();
ensureNetworkWriteBufferSize();
if (currentCapacity == networkWriteBuffer.capacity()) {
return result;
}
packetSize = engine.getSession().getPacketBufferSize();
// There is not enough space in the network buffer for an entire SSL packet. We will
// allocate a buffer with the correct packet size the next time through the loop.
break;
case CLOSED:
if (result.bytesProduced() > 0) {
networkWriteBuffer.flip();
} else {
assert false : "WRAP during close processing should produce close message.";
}
assert result.bytesProduced() > 0 : "WRAP during close processing should produce close message.";
return result;
default:
throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus());
@ -265,23 +254,12 @@ public class SSLDriver implements AutoCloseable {
}
}
private void ensureNetworkWriteBufferSize() {
networkWriteBuffer = ensureNetBufferSize(networkWriteBuffer);
}
private void ensureNetworkReadBufferSize() {
networkReadBuffer = ensureNetBufferSize(networkReadBuffer);
}
private ByteBuffer ensureNetBufferSize(ByteBuffer current) {
int networkPacketSize = engine.getSession().getPacketBufferSize();
if (current.capacity() < networkPacketSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize);
current.flip();
newBuffer.put(current);
return newBuffer;
} else {
return current;
packetSize = engine.getSession().getPacketBufferSize();
if (networkReadBuffer.capacity() < packetSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(packetSize);
networkReadBuffer.flip();
newBuffer.put(networkReadBuffer);
}
}
@ -306,7 +284,7 @@ public class SSLDriver implements AutoCloseable {
void read(InboundChannelBuffer buffer) throws SSLException;
int write(ByteBuffer[] buffers) throws SSLException;
int write(FlushOperation applicationBytes) throws SSLException;
boolean needsNonApplicationWrite();
@ -326,8 +304,7 @@ public class SSLDriver implements AutoCloseable {
private void startHandshake() throws SSLException {
handshakeStatus = engine.getHandshakeStatus();
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
try {
handshake();
} catch (SSLException e) {
@ -346,13 +323,7 @@ public class SSLDriver implements AutoCloseable {
continueHandshaking = false;
break;
case NEED_WRAP:
if (hasFlushPending() == false) {
handshakeStatus = wrap(EMPTY_BUFFER_ARRAY).getHandshakeStatus();
}
// If we need NEED_TASK we should run the tasks immediately
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) {
continueHandshaking = false;
}
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
break;
case NEED_TASK:
runTasks();
@ -390,7 +361,7 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
public int write(FlushOperation applicationBytes) throws SSLException {
try {
handshake();
} catch (SSLException e) {
@ -445,8 +416,7 @@ public class SSLDriver implements AutoCloseable {
String message = "Expected to be in handshaking/closed mode. Instead in application mode.";
throw new AssertionError(message);
}
} else if (hasFlushPending() == false) {
// We only acknowledge that we are done handshaking if there are no bytes that need to be written
} else {
if (currentMode.isHandshake()) {
currentMode = new ApplicationMode();
} else {
@ -473,10 +443,17 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
SSLEngineResult result = wrap(buffers);
maybeRenegotiation(result.getHandshakeStatus());
return result.bytesConsumed();
public int write(FlushOperation applicationBytes) throws SSLException {
boolean continueWrap = true;
int totalBytesProduced = 0;
while (continueWrap && applicationBytes.isFullyFlushed() == false) {
SSLEngineResult result = wrap(outboundBuffer, applicationBytes);
int bytesProduced = result.bytesProduced();
totalBytesProduced += bytesProduced;
boolean renegotiationRequested = maybeRenegotiation(result.getHandshakeStatus());
continueWrap = bytesProduced > 0 && renegotiationRequested == false;
}
return totalBytesProduced;
}
private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException {
@ -560,18 +537,21 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
if (hasFlushPending() == false && engine.isOutboundDone()) {
needToSendClose = false;
// Close inbound if it is still open and we have decided not to wait for response.
if (needToReceiveClose == false && engine.isInboundDone() == false) {
closeInboundAndSwallowPeerDidNotCloseException();
public int write(FlushOperation applicationBytes) throws SSLException {
int bytesProduced = 0;
if (engine.isOutboundDone() == false) {
bytesProduced += wrap(outboundBuffer).bytesProduced();
if (engine.isOutboundDone()) {
needToSendClose = false;
// Close inbound if it is still open and we have decided not to wait for response.
if (needToReceiveClose == false && engine.isInboundDone() == false) {
closeInboundAndSwallowPeerDidNotCloseException();
}
}
} else {
wrap(EMPTY_BUFFER_ARRAY);
assert hasFlushPending() : "Should have produced close message";
needToSendClose = false;
}
return 0;
return bytesProduced;
}
@Override

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.Page;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.function.IntFunction;
public class SSLOutboundBuffer implements AutoCloseable {
private final ArrayDeque<Page> pages;
private final IntFunction<Page> pageSupplier;
private Page currentPage;
SSLOutboundBuffer(IntFunction<Page> pageSupplier) {
this.pages = new ArrayDeque<>();
this.pageSupplier = pageSupplier;
}
void incrementEncryptedBytes(int encryptedBytesProduced) {
if (encryptedBytesProduced != 0) {
currentPage.byteBuffer().limit(encryptedBytesProduced);
pages.addLast(currentPage);
}
currentPage = null;
}
ByteBuffer nextWriteBuffer(int networkBufferSize) {
if (currentPage != null) {
// If there is an existing page, close it as it wasn't large enough to accommodate the SSLEngine.
currentPage.close();
}
Page newPage = pageSupplier.apply(networkBufferSize);
currentPage = newPage;
return newPage.byteBuffer().duplicate();
}
FlushOperation buildNetworkFlushOperation() {
int pageCount = pages.size();
ByteBuffer[] byteBuffers = new ByteBuffer[pageCount];
Page[] pagesToClose = new Page[pageCount];
for (int i = 0; i < pageCount; ++i) {
Page page = pages.removeFirst();
pagesToClose[i] = page;
byteBuffers[i] = page.byteBuffer();
}
return new FlushOperation(byteBuffers, (r, e) -> IOUtils.closeWhileHandlingException(pagesToClose));
}
boolean hasEncryptedBytesToFlush() {
return pages.isEmpty() == false;
}
@Override
public void close() {
IOUtils.closeWhileHandlingException(pages);
}
}

View File

@ -22,6 +22,7 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -92,9 +93,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig);

View File

@ -21,6 +21,7 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -155,9 +156,9 @@ public class SecurityNioTransport extends NioTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);

View File

@ -8,10 +8,12 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.FlushReadyWrite;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase;
@ -28,6 +30,7 @@ import java.util.function.Consumer;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@ -43,13 +46,13 @@ public class SSLChannelContextTests extends ESTestCase {
private SocketChannel rawChannel;
private SSLChannelContext context;
private InboundChannelBuffer channelBuffer;
private SSLOutboundBuffer outboundBuffer;
private NioSelector selector;
private TaskScheduler nioTimer;
private BiConsumer<Void, Exception> listener;
private Consumer exceptionHandler;
private SSLDriver sslDriver;
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14);
private int messageLength;
@Before
@ -66,6 +69,7 @@ public class SSLChannelContextTests extends ESTestCase {
rawChannel = mock(SocketChannel.class);
sslDriver = mock(SSLDriver.class);
channelBuffer = InboundChannelBuffer.allocatingInstance();
outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
@ -73,7 +77,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(selector.isOnCurrentThread()).thenReturn(true);
when(selector.getTaskScheduler()).thenReturn(nioTimer);
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer);
when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer);
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
buffer.clear();
@ -85,7 +89,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
@ -100,7 +104,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength * 2);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
@ -115,7 +119,7 @@ public class SSLChannelContextTests extends ESTestCase {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.apply(channelBuffer)).thenReturn(0);
@ -173,7 +177,6 @@ public class SSLChannelContextTests extends ESTestCase {
public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
context.queueWriteOperation(mock(FlushReadyWrite.class));
@ -181,25 +184,25 @@ public class SSLChannelContextTests extends ESTestCase {
assertFalse(context.readyForFlush());
}
public void testPendingFlushMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(randomBoolean());
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
public void testPendingEncryptedFlushMeansWriteInterested() throws Exception {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite();
// Call will put bytes in buffer to flush
context.flushChannel();
assertTrue(context.readyForFlush());
}
public void testNeedsNonAppWritesMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
assertTrue(context.readyForFlush());
}
public void testNotWritesInterestInAppMode() {
public void testNoNonAppWriteInterestInAppMode() {
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.hasFlushPending()).thenReturn(false);
assertFalse(context.readyForFlush());
@ -207,18 +210,25 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testFirstFlushMustFinishForWriteToContinue() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(true, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite();
// First call will put bytes in buffer to flush
context.flushChannel();
assertTrue(context.readyForFlush());
// Second call will will not continue generating non-app bytes because they still need to be flushed
context.flushChannel();
assertTrue(context.readyForFlush());
verify(sslDriver, times(0)).nonApplicationWrite();
verify(sslDriver, times(1)).nonApplicationWrite();
}
public void testNonAppWrites() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, false, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite();
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1);
context.flushChannel();
@ -227,9 +237,10 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, true, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite();
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0);
context.flushChannel();
@ -239,34 +250,28 @@ public class SSLChannelContextTests extends ESTestCase {
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
context.queueWriteOperation(flushOperation);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(10);
when(flushOperation.isFullyFlushed()).thenReturn(false,true);
doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10);
context.flushChannel();
verify(flushOperation).incrementIndex(10);
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
verify(selector).executeListener(listener, null);
assertFalse(context.readyForFlush());
}
public void testPartialFlush() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
ByteBuffer[] buffers = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
context.queueWriteOperation(flushOperation);
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(flushOperation.isFullyFlushed()).thenReturn(false, false);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4);
context.flushChannel();
verify(rawChannel, times(1)).write(same(selector.getIoBuffer()));
@ -279,24 +284,16 @@ public class SSLChannelContextTests extends ESTestCase {
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)};
ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
when(flushOperation1.getBuffersToWrite()).thenReturn(buffers1);
when(flushOperation2.getBuffersToWrite()).thenReturn(buffers2);
when(flushOperation1.getListener()).thenReturn(listener);
when(flushOperation2.getListener()).thenReturn(listener2);
FlushReadyWrite flushOperation1 = new FlushReadyWrite(context, buffers1, listener);
FlushReadyWrite flushOperation2 = new FlushReadyWrite(context, buffers2, listener2);
context.queueWriteOperation(flushOperation1);
context.queueWriteOperation(flushOperation2);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5);
when(sslDriver.applicationWrite(buffers2)).thenReturn(3);
when(flushOperation1.isFullyFlushed()).thenReturn(false, false, true);
when(flushOperation2.isFullyFlushed()).thenReturn(false);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class));
when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2);
context.flushChannel();
verify(flushOperation1, times(2)).incrementIndex(5);
verify(rawChannel, times(3)).write(same(selector.getIoBuffer()));
verify(selector).executeListener(listener, null);
verify(selector, times(0)).executeListener(listener2, null);
@ -304,29 +301,27 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
FlushReadyWrite flushOperation = mock(FlushReadyWrite.class);
ByteBuffer[] buffers = {ByteBuffer.allocate(5)};
FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener);
context.queueWriteOperation(flushOperation);
IOException exception = new IOException();
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
when(flushOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation));
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
when(flushOperation.isFullyFlushed()).thenReturn(false);
expectThrows(IOException.class, () -> context.flushChannel());
verify(flushOperation).incrementIndex(5);
verify(selector).executeFailedListener(listener, exception);
assertFalse(context.readyForFlush());
}
public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite();
context.flushChannel();
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
@ -413,7 +408,21 @@ public class SSLChannelContextTests extends ESTestCase {
}
}
private Answer getAnswerForBytes(byte[] bytes) {
private Answer<Integer> getWriteAnswer(int bytesToEncrypt, boolean isApp) {
return invocationOnMock -> {
ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1);
for (int i = 0; i < bytesToEncrypt; ++i) {
byteBuffer.put((byte) i);
}
outboundBuffer.incrementEncryptedBytes(bytesToEncrypt);
if (isApp) {
((FlushOperation) invocationOnMock.getArguments()[0]).incrementIndex(bytesToEncrypt);
}
return bytesToEncrypt;
};
}
private Answer getReadAnswerForBytes(byte[] bytes) {
return invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
buffer.ensureCapacity(buffer.getIndex() + bytes.length);

View File

@ -6,7 +6,9 @@
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.bootstrap.JavaVersion;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ssl.CertParsingUtils;
import org.elasticsearch.xpack.core.ssl.PemUtils;
@ -28,8 +30,7 @@ import java.util.function.Supplier;
public class SSLDriverTests extends ESTestCase {
private final Supplier<InboundChannelBuffer.Page> pageSupplier =
() -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {});
private final Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {});
private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
@ -141,10 +142,6 @@ public class SSLDriverTests extends ESTestCase {
boolean expectedMessage = oldExpected.equals(sslException.getMessage()) || jdk11Expected.equals(sslException.getMessage());
assertTrue("Unexpected exception message: " + sslException.getMessage(), expectedMessage);
// In JDK11 we need an non-application write
if (serverDriver.needsNonApplicationWrite()) {
serverDriver.nonApplicationWrite();
}
// Prior to JDK11 we still need to send a close alert
if (serverDriver.isClosed() == false) {
failedCloseAlert(serverDriver, clientDriver, Arrays.asList("Received fatal alert: protocol_version",
@ -166,10 +163,7 @@ public class SSLDriverTests extends ESTestCase {
SSLDriver serverDriver = getDriver(serverEngine, false);
expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver));
// In JDK11 we need an non-application write
if (serverDriver.needsNonApplicationWrite()) {
serverDriver.nonApplicationWrite();
}
// Prior to JDK11 we still need to send a close alert
if (serverDriver.isClosed() == false) {
List<String> messages = Arrays.asList("Received fatal alert: handshake_failure",
@ -187,13 +181,11 @@ public class SSLDriverTests extends ESTestCase {
clientDriver.init();
serverDriver.init();
assertTrue(clientDriver.needsNonApplicationWrite());
assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
sendHandshakeMessages(serverDriver, clientDriver);
sendData(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
@ -222,13 +214,11 @@ public class SSLDriverTests extends ESTestCase {
clientDriver.init();
serverDriver.init();
assertTrue(clientDriver.needsNonApplicationWrite());
assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
sendHandshakeMessages(serverDriver, clientDriver);
sendData(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
@ -239,9 +229,6 @@ public class SSLDriverTests extends ESTestCase {
sendNonApplicationWrites(serverDriver, clientDriver);
// We are immediately fully closed due to SSLEngine inconsistency
assertTrue(serverDriver.isClosed());
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
clientDriver.read(clientBuffer);
sendNonApplicationWrites(clientDriver, serverDriver);
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer));
assertEquals("Received close_notify during handshake", sslException.getMessage());
assertTrue(clientDriver.needsNonApplicationWrite());
@ -306,13 +293,13 @@ public class SSLDriverTests extends ESTestCase {
}
private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
if (sendDriver.hasFlushPending() == false) {
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
} else {
sendDriver.nonApplicationWrite();
}
if (sendDriver.hasFlushPending()) {
sendData(sendDriver, receiveDriver, true);
}
}
}
@ -326,7 +313,7 @@ public class SSLDriverTests extends ESTestCase {
serverDriver.init();
}
assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending());
assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
@ -341,7 +328,6 @@ public class SSLDriverTests extends ESTestCase {
sendHandshakeMessages(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
sendHandshakeMessages(serverDriver, clientDriver);
@ -350,58 +336,51 @@ public class SSLDriverTests extends ESTestCase {
}
private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException {
assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending());
assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.getOutboundBuffer().hasEncryptedBytesToFlush());
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
if (sendDriver.hasFlushPending() == false) {
sendDriver.nonApplicationWrite();
}
if (sendDriver.isHandshaking()) {
assertTrue(sendDriver.hasFlushPending());
sendData(sendDriver, receiveDriver);
assertFalse(sendDriver.hasFlushPending());
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
receiveDriver.read(genericBuffer);
} else {
sendDriver.nonApplicationWrite();
}
}
if (receiveDriver.isHandshaking()) {
assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.hasFlushPending());
assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.getOutboundBuffer().hasEncryptedBytesToFlush());
}
}
private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
assertFalse(sendDriver.needsNonApplicationWrite());
int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {});
int bytesEncrypted = 0;
while (bytesToEncrypt > bytesEncrypted) {
bytesEncrypted += sendDriver.applicationWrite(message);
sendData(sendDriver, receiveDriver);
bytesEncrypted += sendDriver.write(flushOperation);
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
}
}
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver) {
sendData(sendDriver, receiveDriver, randomBoolean());
}
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver, boolean partial) {
ByteBuffer writeBuffer = sendDriver.getNetworkWriteBuffer();
private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) {
ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
if (partial) {
int initialLimit = writeBuffer.limit();
int bytesToWrite = writeBuffer.remaining() / (randomInt(2) + 2);
writeBuffer.limit(writeBuffer.position() + bytesToWrite);
readBuffer.put(writeBuffer);
writeBuffer.limit(initialLimit);
assertTrue(sendDriver.hasFlushPending());
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite();
int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer";
assert writeBuffers.length > 0 : "No write buffers";
} else {
for (ByteBuffer writeBuffer : writeBuffers) {
int written = writeBuffer.remaining();
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
flushOperation.incrementIndex(written);
}
assertTrue(flushOperation.isFullyFlushed());
}
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {