Read multiple TLS packets in one read call (#41820)

This is related to #27260. Currently we have a single read buffer that
is no larger than a single TLS packet. This prevents us from reading
multiple TLS packets in a single socket read call. This commit modifies
our TLS work to support reading similar to the plaintext case. The data
will be copied to a (potentially) recycled TLS packet-sized buffer for
interaction with the SSLEngine.
This commit is contained in:
Tim Brooks 2019-05-06 09:51:32 -06:00 committed by GitHub
parent 228d23de6d
commit 927013426a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 407 additions and 374 deletions

View File

@ -27,7 +27,7 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.function.IntFunction;
/**
* This is a channel byte buffer composed internally of 16kb pages. When an entire message has been read
@ -37,15 +37,14 @@ import java.util.function.Supplier;
*/
public final class InboundChannelBuffer implements AutoCloseable {
private static final int PAGE_SIZE = 1 << 14;
public static final int PAGE_SIZE = 1 << 14;
private static final int PAGE_MASK = PAGE_SIZE - 1;
private static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE);
private static final ByteBuffer[] EMPTY_BYTE_BUFFER_ARRAY = new ByteBuffer[0];
private static final Page[] EMPTY_BYTE_PAGE_ARRAY = new Page[0];
private final ArrayDeque<Page> pages;
private final Supplier<Page> pageSupplier;
private final IntFunction<Page> pageAllocator;
private final ArrayDeque<Page> pages = new ArrayDeque<>();
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private long capacity = 0;
@ -53,14 +52,12 @@ public final class InboundChannelBuffer implements AutoCloseable {
// The offset is an int as it is the offset of where the bytes begin in the first buffer
private int offset = 0;
public InboundChannelBuffer(Supplier<Page> pageSupplier) {
this.pageSupplier = pageSupplier;
this.pages = new ArrayDeque<>();
this.capacity = PAGE_SIZE * pages.size();
public InboundChannelBuffer(IntFunction<Page> pageAllocator) {
this.pageAllocator = pageAllocator;
}
public static InboundChannelBuffer allocatingInstance() {
return new InboundChannelBuffer(() -> new Page(ByteBuffer.allocate(PAGE_SIZE), () -> {}));
return new InboundChannelBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
}
@Override
@ -87,7 +84,7 @@ public final class InboundChannelBuffer implements AutoCloseable {
int numPages = numPages(requiredCapacity + offset);
int pagesToAdd = numPages - pages.size();
for (int i = 0; i < pagesToAdd; i++) {
Page page = pageSupplier.get();
Page page = pageAllocator.apply(PAGE_SIZE);
pages.addLast(page);
}
capacity += pagesToAdd * PAGE_SIZE;

View File

@ -20,6 +20,7 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.concurrent.CompletableContext;
import org.elasticsearch.nio.utils.ByteBufferUtils;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException;
@ -249,26 +250,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
// data that is copied to the buffer for a write, but not successfully flushed immediately, must be
// copied again on the next call.
protected int readFromChannel(ByteBuffer buffer) throws IOException {
ByteBuffer ioBuffer = getSelector().getIoBuffer();
ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit()));
int bytesRead;
try {
bytesRead = rawChannel.read(ioBuffer);
} catch (IOException e) {
closeNow = true;
throw e;
}
if (bytesRead < 0) {
closeNow = true;
return 0;
} else {
ioBuffer.flip();
buffer.put(ioBuffer);
return bytesRead;
}
}
protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException {
ByteBuffer ioBuffer = getSelector().getIoBuffer();
int bytesRead;
@ -288,7 +269,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
int j = 0;
while (j < buffers.length && ioBuffer.remaining() > 0) {
ByteBuffer buffer = buffers[j++];
copyBytes(ioBuffer, buffer);
ByteBufferUtils.copyBytes(ioBuffer, buffer);
}
channelBuffer.incrementIndex(bytesRead);
return bytesRead;
@ -299,24 +280,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
// copying.
private final int WRITE_LIMIT = 1 << 16;
protected int flushToChannel(ByteBuffer buffer) throws IOException {
int initialPosition = buffer.position();
ByteBuffer ioBuffer = getSelector().getIoBuffer();
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
copyBytes(buffer, ioBuffer);
ioBuffer.flip();
int bytesWritten;
try {
bytesWritten = rawChannel.write(ioBuffer);
} catch (IOException e) {
closeNow = true;
buffer.position(initialPosition);
throw e;
}
buffer.position(initialPosition + bytesWritten);
return bytesWritten;
}
protected int flushToChannel(FlushOperation flushOperation) throws IOException {
ByteBuffer ioBuffer = getSelector().getIoBuffer();
@ -325,12 +288,8 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
while (continueFlush) {
ioBuffer.clear();
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
int j = 0;
ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT);
while (j < buffers.length && ioBuffer.remaining() > 0) {
ByteBuffer buffer = buffers[j++];
copyBytes(buffer, ioBuffer);
}
ByteBufferUtils.copyBytes(buffers, ioBuffer);
ioBuffer.flip();
int bytesFlushed;
try {
@ -345,12 +304,4 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
}
return totalBytesFlushed;
}
private void copyBytes(ByteBuffer from, ByteBuffer to) {
int nBytesToCopy = Math.min(to.remaining(), from.remaining());
int initialLimit = from.limit();
from.limit(from.position() + nBytesToCopy);
to.put(from);
from.limit(initialLimit);
}
}

View File

@ -0,0 +1,63 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio.utils;
import java.nio.ByteBuffer;
public final class ByteBufferUtils {
private ByteBufferUtils() {}
/**
* Copies bytes from the array of byte buffers into the destination buffer. The number of bytes copied is
* limited by the bytes available to copy and the space remaining in the destination byte buffer.
*
* @param source byte buffers to copy from
* @param destination byte buffer to copy to
*
* @return number of bytes copied
*/
public static long copyBytes(ByteBuffer[] source, ByteBuffer destination) {
long bytesCopied = 0;
for (int i = 0; i < source.length && destination.hasRemaining(); i++) {
ByteBuffer buffer = source[i];
bytesCopied += copyBytes(buffer, destination);
}
return bytesCopied;
}
/**
* Copies bytes from source byte buffer into the destination buffer. The number of bytes copied is
* limited by the bytes available to copy and the space remaining in the destination byte buffer.
*
* @param source byte buffer to copy from
* @param destination byte buffer to copy to
*
* @return number of bytes copied
*/
public static int copyBytes(ByteBuffer source, ByteBuffer destination) {
int nBytesToCopy = Math.min(destination.remaining(), source.remaining());
int initialLimit = source.limit();
source.limit(source.position() + nBytesToCopy);
destination.put(source);
source.limit(initialLimit);
return nBytesToCopy;
}
}

View File

@ -19,23 +19,25 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.test.ESTestCase;
import java.nio.ByteBuffer;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.function.IntFunction;
public class InboundChannelBufferTests extends ESTestCase {
private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES;
private final Supplier<Page> defaultPageSupplier = () ->
new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
});
private IntFunction<Page> defaultPageAllocator;
@Override
public void setUp() throws Exception {
super.setUp();
defaultPageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
}
public void testNewBufferNoPages() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
assertEquals(0, channelBuffer.getCapacity());
assertEquals(0, channelBuffer.getRemaining());
@ -43,107 +45,107 @@ public class InboundChannelBufferTests extends ESTestCase {
}
public void testExpandCapacity() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
assertEquals(0, channelBuffer.getCapacity());
assertEquals(0, channelBuffer.getRemaining());
channelBuffer.ensureCapacity(PAGE_SIZE);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
channelBuffer.ensureCapacity(PAGE_SIZE + 1);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
assertEquals(PAGE_SIZE * 2, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getRemaining());
}
public void testExpandCapacityMultiplePages() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
channelBuffer.ensureCapacity(PAGE_SIZE);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
int multiple = randomInt(80);
channelBuffer.ensureCapacity(PAGE_SIZE + ((multiple * PAGE_SIZE) - randomInt(500)));
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + ((multiple * InboundChannelBuffer.PAGE_SIZE) - randomInt(500)));
assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
}
public void testExpandCapacityRespectsOffset() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
channelBuffer.ensureCapacity(PAGE_SIZE);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
int offset = randomInt(300);
channelBuffer.release(offset);
assertEquals(PAGE_SIZE - offset, channelBuffer.getCapacity());
assertEquals(PAGE_SIZE - offset, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getRemaining());
channelBuffer.ensureCapacity(PAGE_SIZE + 1);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
}
public void testIncrementIndex() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
channelBuffer.ensureCapacity(PAGE_SIZE);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
assertEquals(0, channelBuffer.getIndex());
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
channelBuffer.incrementIndex(10);
assertEquals(10, channelBuffer.getIndex());
assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
}
public void testIncrementIndexWithOffset() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
channelBuffer.ensureCapacity(PAGE_SIZE);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
assertEquals(0, channelBuffer.getIndex());
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
channelBuffer.release(10);
assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
channelBuffer.incrementIndex(10);
assertEquals(10, channelBuffer.getIndex());
assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
channelBuffer.release(2);
assertEquals(8, channelBuffer.getIndex());
assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
}
public void testReleaseClosesPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
IntFunction<Page> allocator = (n) -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
assertEquals(PAGE_SIZE * 4, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 4, channelBuffer.getCapacity());
assertEquals(4, queue.size());
for (AtomicBoolean closedRef : queue) {
assertFalse(closedRef.get());
}
channelBuffer.release(2 * PAGE_SIZE);
channelBuffer.release(2 * InboundChannelBuffer.PAGE_SIZE);
assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
assertTrue(queue.poll().get());
assertTrue(queue.poll().get());
@ -153,13 +155,13 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testClose() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
IntFunction<Page> allocator = (n) -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
assertEquals(4, queue.size());
@ -178,13 +180,13 @@ public class InboundChannelBufferTests extends ESTestCase {
public void testCloseRetainedPages() {
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
Supplier<Page> supplier = () -> {
IntFunction<Page> allocator = (n) -> {
AtomicBoolean atomicBoolean = new AtomicBoolean();
queue.add(atomicBoolean);
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
};
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
assertEquals(4, queue.size());
@ -192,7 +194,7 @@ public class InboundChannelBufferTests extends ESTestCase {
assertFalse(closedRef.get());
}
Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
Page[] pages = channelBuffer.sliceAndRetainPagesTo(InboundChannelBuffer.PAGE_SIZE * 2);
pages[1].close();
@ -220,10 +222,10 @@ public class InboundChannelBufferTests extends ESTestCase {
}
public void testAccessByteBuffers() {
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
int pages = randomInt(50) + 5;
channelBuffer.ensureCapacity(pages * PAGE_SIZE);
channelBuffer.ensureCapacity(pages * InboundChannelBuffer.PAGE_SIZE);
long capacity = channelBuffer.getCapacity();

View File

@ -34,8 +34,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
@ -285,8 +285,8 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true);
Runnable closer = mock(Runnable.class);
Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator);
buffer.ensureCapacity(1);
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);
context.closeFromSelector();
@ -294,29 +294,6 @@ public class SocketChannelContextTests extends ESTestCase {
}
}
public void testReadToBufferLimitsToPassedBuffer() throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(10);
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
int bytesRead = context.readFromChannel(buffer);
assertEquals(bytesRead, 10);
assertEquals(0, buffer.remaining());
}
public void testReadToBufferHandlesIOException() throws IOException {
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10)));
assertTrue(context.closeNow());
}
public void testReadToBufferHandlesEOF() throws IOException {
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
context.readFromChannel(ByteBuffer.allocate(10));
assertTrue(context.closeNow());
}
public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException {
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
@ -344,33 +321,6 @@ public class SocketChannelContextTests extends ESTestCase {
assertEquals(0, channelBuffer.getIndex());
}
public void testFlushBufferHandlesPartialFlush() throws IOException {
int bytesToConsume = 3;
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
ByteBuffer buffer = ByteBuffer.allocate(10);
context.flushToChannel(buffer);
assertEquals(10 - bytesToConsume, buffer.remaining());
}
public void testFlushBufferHandlesFullFlush() throws IOException {
int bytesToConsume = 10;
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
ByteBuffer buffer = ByteBuffer.allocate(10);
context.flushToChannel(buffer);
assertEquals(0, buffer.remaining());
}
public void testFlushBufferHandlesIOException() throws IOException {
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
ByteBuffer buffer = ByteBuffer.allocate(10);
expectThrows(IOException.class, () -> context.flushToChannel(buffer));
assertTrue(context.closeNow());
assertEquals(10, buffer.remaining());
}
public void testFlushBuffersHandlesZeroFlush() throws IOException {
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0));
@ -456,22 +406,14 @@ public class SocketChannelContextTests extends ESTestCase {
@Override
public int read() throws IOException {
if (randomBoolean()) {
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
return readFromChannel(channelBuffer);
} else {
return readFromChannel(ByteBuffer.allocate(10));
}
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
return readFromChannel(channelBuffer);
}
@Override
public void flushChannel() throws IOException {
if (randomBoolean()) {
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
} else {
flushToChannel(ByteBuffer.allocate(10));
}
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
}
@Override

View File

@ -25,7 +25,6 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -43,16 +42,15 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.RestUtils;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.nio.NioGroupFactory;
import org.elasticsearch.transport.nio.PageAllocator;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Arrays;
@ -80,8 +78,8 @@ import static org.elasticsearch.http.nio.cors.NioCorsHandler.ANY_ORIGIN;
public class NioHttpServerTransport extends AbstractHttpServerTransport {
private static final Logger logger = LogManager.getLogger(NioHttpServerTransport.class);
protected final PageCacheRecycler pageCacheRecycler;
protected final NioCorsConfig corsConfig;
protected final PageAllocator pageAllocator;
private final NioGroupFactory nioGroupFactory;
protected final boolean tcpNoDelay;
@ -97,7 +95,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
PageCacheRecycler pageCacheRecycler, ThreadPool threadPool, NamedXContentRegistry xContentRegistry,
Dispatcher dispatcher, NioGroupFactory nioGroupFactory) {
super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
this.pageCacheRecycler = pageCacheRecycler;
this.pageAllocator = new PageAllocator(pageCacheRecycler);
this.nioGroupFactory = nioGroupFactory;
ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
@ -206,15 +204,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
java.util.function.Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline,
new InboundChannelBuffer(pageSupplier));
new InboundChannelBuffer(pageAllocator));
httpChannel.setContext(context);
return httpChannel;
}

View File

@ -26,7 +26,6 @@ import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@ -36,20 +35,17 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
@ -57,6 +53,7 @@ public class NioTransport extends TcpTransport {
private static final Logger logger = LogManager.getLogger(NioTransport.class);
protected final PageAllocator pageAllocator;
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private final NioGroupFactory groupFactory;
private volatile NioGroup nioGroup;
@ -66,6 +63,7 @@ public class NioTransport extends TcpTransport {
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService, NioGroupFactory groupFactory) {
super(settings, version, threadPool, pageCacheRecycler, circuitBreakerService, namedWriteableRegistry, networkService);
this.pageAllocator = new PageAllocator(pageCacheRecycler);
this.groupFactory = groupFactory;
}
@ -158,14 +156,10 @@ public class NioTransport extends TcpTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler,
new InboundChannelBuffer(pageSupplier));
new InboundChannelBuffer(pageAllocator));
nioChannel.setContext(context);
return nioChannel;
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.nio.Page;
import java.nio.ByteBuffer;
import java.util.function.IntFunction;
public class PageAllocator implements IntFunction<Page> {
private static final int RECYCLE_LOWER_THRESHOLD = PageCacheRecycler.BYTE_PAGE_SIZE / 2;
private final PageCacheRecycler recycler;
public PageAllocator(PageCacheRecycler recycler) {
this.recycler = recycler;
}
@Override
public Page apply(int length) {
if (length >= RECYCLE_LOWER_THRESHOLD && length <= PageCacheRecycler.BYTE_PAGE_SIZE){
Recycler.V<byte[]> bytePage = recycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytePage.v(), 0, length), bytePage::close);
} else {
return new Page(ByteBuffer.allocate(length), () -> {});
}
}
}

View File

@ -37,8 +37,8 @@ import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
@ -61,7 +61,7 @@ import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.function.IntFunction;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
@ -192,9 +192,13 @@ public class MockNioTransport extends TcpTransport {
@Override
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
IntFunction<Page> pageSupplier = (length) -> {
if (length > PageCacheRecycler.BYTE_PAGE_SIZE) {
return new Page(ByteBuffer.allocate(length), () -> {});
} else {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v(), 0, length), bytes::close);
}
};
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),

View File

@ -36,19 +36,22 @@ public final class SSLChannelContext extends SocketChannelContext {
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
private final SSLDriver sslDriver;
private final InboundChannelBuffer networkReadBuffer;
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
applicationBuffer, ALWAYS_ALLOW_CHANNEL);
}
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
this.sslDriver = sslDriver;
this.networkReadBuffer = networkReadBuffer;
}
@Override
@ -157,12 +160,12 @@ public final class SSLChannelContext extends SocketChannelContext {
if (closeNow()) {
return bytesRead;
}
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
bytesRead = readFromChannel(networkReadBuffer);
if (bytesRead == 0) {
return bytesRead;
}
sslDriver.read(channelBuffer);
sslDriver.read(networkReadBuffer, channelBuffer);
handleReadBytes();
// It is possible that a read call produced non-application bytes to flush
@ -201,7 +204,7 @@ public final class SSLChannelContext extends SocketChannelContext {
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
}
encryptedFlushes.clear();
IOUtils.close(super::closeFromSelector, sslDriver::close);
IOUtils.close(super::closeFromSelector, networkReadBuffer::close, sslDriver::close);
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.utils.ByteBufferUtils;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import javax.net.ssl.SSLEngine;
@ -16,6 +17,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.function.IntFunction;
/**
* SSLDriver is a class that wraps the {@link SSLEngine} and attempts to simplify the API. The basic usage is
@ -27,9 +29,9 @@ import java.util.ArrayList;
* application to be written to the wire.
*
* Handling reads from a channel with this class is very simple. When data has been read, call
* {@link #read(InboundChannelBuffer)}. If the data is application data, it will be decrypted and placed into
* the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
* or handshake process.
* {@link #read(InboundChannelBuffer, InboundChannelBuffer)}. If the data is application data, it will be
* decrypted and placed into the application buffer passed as an argument. Otherwise, it will be consumed
* internally and advance the SSL/TLS close or handshake process.
*
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
* called to determine if this driver needs to produce more data to advance the handshake or close process.
@ -54,21 +56,22 @@ public class SSLDriver implements AutoCloseable {
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
private final SSLEngine engine;
// TODO: When the bytes are actually recycled, we need to test that they are released on driver close
private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
private final IntFunction<Page> pageAllocator;
private final SSLOutboundBuffer outboundBuffer;
private Page networkReadPage;
private final boolean isClientMode;
// This should only be accessed by the network thread associated with this channel, so nothing needs to
// be volatile.
private Mode currentMode = new HandshakeMode();
private ByteBuffer networkReadBuffer;
private int packetSize;
public SSLDriver(SSLEngine engine, boolean isClientMode) {
public SSLDriver(SSLEngine engine, IntFunction<Page> pageAllocator, boolean isClientMode) {
this.engine = engine;
this.pageAllocator = pageAllocator;
this.outboundBuffer = new SSLOutboundBuffer(pageAllocator);
this.isClientMode = isClientMode;
SSLSession session = engine.getSession();
packetSize = session.getPacketBufferSize();
this.networkReadBuffer = ByteBuffer.allocate(packetSize);
}
public void init() throws SSLException {
@ -106,22 +109,25 @@ public class SSLDriver implements AutoCloseable {
return currentMode.isHandshake();
}
public ByteBuffer getNetworkReadBuffer() {
return networkReadBuffer;
}
public SSLOutboundBuffer getOutboundBuffer() {
return outboundBuffer;
}
public void read(InboundChannelBuffer buffer) throws SSLException {
Mode modePriorToRead;
do {
modePriorToRead = currentMode;
currentMode.read(buffer);
// If we switched modes we want to read again as there might be unhandled bytes that need to be
// handled by the new mode.
} while (modePriorToRead != currentMode);
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
networkReadPage = pageAllocator.apply(packetSize);
try {
Mode modePriorToRead;
do {
modePriorToRead = currentMode;
currentMode.read(encryptedBuffer, applicationBuffer);
// It is possible that we received multiple SSL packets from the network since the last read.
// If one of those packets causes us to change modes (such as finished handshaking), we need
// to call read in the new mode to handle the remaining packets.
} while (modePriorToRead != currentMode);
} finally {
networkReadPage.close();
networkReadPage = null;
}
}
public boolean readyForApplicationWrites() {
@ -171,27 +177,34 @@ public class SSLDriver implements AutoCloseable {
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
}
private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException {
private SSLEngineResult unwrap(InboundChannelBuffer networkBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
while (true) {
SSLEngineResult result = engine.unwrap(networkReadBuffer, buffer.sliceBuffersFrom(buffer.getIndex()));
buffer.incrementIndex(result.bytesProduced());
ensureApplicationBufferSize(applicationBuffer);
ByteBuffer networkReadBuffer = networkReadPage.byteBuffer();
networkReadBuffer.clear();
ByteBufferUtils.copyBytes(networkBuffer.sliceBuffersTo(Math.min(networkBuffer.getIndex(), packetSize)), networkReadBuffer);
networkReadBuffer.flip();
SSLEngineResult result = engine.unwrap(networkReadBuffer, applicationBuffer.sliceBuffersFrom(applicationBuffer.getIndex()));
networkBuffer.release(result.bytesConsumed());
applicationBuffer.incrementIndex(result.bytesProduced());
switch (result.getStatus()) {
case OK:
networkReadBuffer.compact();
return result;
case BUFFER_UNDERFLOW:
// There is not enough space in the network buffer for an entire SSL packet. Compact the
// current data and expand the buffer if necessary.
int currentCapacity = networkReadBuffer.capacity();
ensureNetworkReadBufferSize();
if (currentCapacity == networkReadBuffer.capacity()) {
networkReadBuffer.compact();
packetSize = engine.getSession().getPacketBufferSize();
if (networkReadPage.byteBuffer().capacity() < packetSize) {
networkReadPage.close();
networkReadPage = pageAllocator.apply(packetSize);
} else {
return result;
}
return result;
break;
case BUFFER_OVERFLOW:
// There is not enough space in the application buffer for the decrypted message. Expand
// the application buffer to ensure that it has enough space.
ensureApplicationBufferSize(buffer);
ensureApplicationBufferSize(applicationBuffer);
break;
case CLOSED:
assert engine.isInboundDone() : "We received close_notify so read should be done";
@ -254,15 +267,6 @@ public class SSLDriver implements AutoCloseable {
}
}
private void ensureNetworkReadBufferSize() {
packetSize = engine.getSession().getPacketBufferSize();
if (networkReadBuffer.capacity() < packetSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(packetSize);
networkReadBuffer.flip();
newBuffer.put(networkReadBuffer);
}
}
// There are three potential modes for the driver to be in - HANDSHAKE, APPLICATION, or CLOSE. HANDSHAKE
// is the initial mode. During this mode data that is read and written will be related to the TLS
// handshake process. Application related data cannot be encrypted until the handshake is complete. From
@ -282,7 +286,7 @@ public class SSLDriver implements AutoCloseable {
private interface Mode {
void read(InboundChannelBuffer buffer) throws SSLException;
void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException;
int write(FlushOperation applicationBytes) throws SSLException;
@ -342,13 +346,11 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
ensureApplicationBufferSize(buffer);
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
try {
SSLEngineResult result = unwrap(buffer);
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
handshakeStatus = result.getHandshakeStatus();
handshake();
// If we are done handshaking we should exit the handshake read
@ -430,12 +432,10 @@ public class SSLDriver implements AutoCloseable {
private class ApplicationMode implements Mode {
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
ensureApplicationBufferSize(buffer);
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer);
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
&& maybeRenegotiation(result.getHandshakeStatus());
continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false;
@ -515,7 +515,7 @@ public class SSLDriver implements AutoCloseable {
}
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
if (needToReceiveClose == false) {
// There is an issue where receiving handshake messages after initiating the close process
// can place the SSLEngine back into handshaking mode. In order to handle this, if we
@ -524,11 +524,9 @@ public class SSLDriver implements AutoCloseable {
return;
}
ensureApplicationBufferSize(buffer);
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer);
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
continueUnwrap = result.bytesProduced() > 0 || result.bytesConsumed() > 0;
}
if (engine.isInboundDone()) {

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
@ -22,7 +21,6 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -35,11 +33,9 @@ import org.elasticsearch.xpack.security.transport.filter.IPFilter;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.XPackSettings.HTTP_SSL_ENABLED;
@ -93,13 +89,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
Consumer<Exception> exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e);
SocketChannelContext context;
@ -113,10 +105,12 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
} else {
sslEngine = sslService.createSSLEngine(sslConfiguration, null, -1);
}
SSLDriver sslDriver = new SSLDriver(sslEngine, false);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, buffer, nioIpFilter);
SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer,
applicationBuffer, nioIpFilter);
} else {
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, buffer, nioIpFilter);
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter);
}
httpChannel.setContext(context);

View File

@ -12,7 +12,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@ -21,7 +20,6 @@ import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
@ -45,14 +43,12 @@ import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.security.SecurityField.setting;
@ -156,20 +152,18 @@ public class SecurityNioTransport extends NioTransport {
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
Supplier<Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
SocketChannelContext context;
if (sslEnabled) {
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter);
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer,
applicationBuffer, ipFilter);
} else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter);
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter);
}
nioChannel.setContext(context);

View File

@ -52,7 +52,6 @@ public class SSLChannelContextTests extends ESTestCase {
private BiConsumer<Void, Exception> listener;
private Consumer exceptionHandler;
private SSLDriver sslDriver;
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
private int messageLength;
@Before
@ -76,7 +75,6 @@ public class SSLChannelContextTests extends ESTestCase {
when(selector.isOnCurrentThread()).thenReturn(true);
when(selector.getTaskScheduler()).thenReturn(nioTimer);
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer);
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
@ -88,8 +86,12 @@ public class SSLChannelContextTests extends ESTestCase {
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
buffer.put(bytes);
return bytes.length;
});
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
@ -103,8 +105,12 @@ public class SSLChannelContextTests extends ESTestCase {
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
buffer.put(bytes);
return bytes.length;
});
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
@ -118,8 +124,12 @@ public class SSLChannelContextTests extends ESTestCase {
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
buffer.put(bytes);
return bytes.length;
});
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
when(readConsumer.apply(channelBuffer)).thenReturn(0);
@ -424,12 +434,12 @@ public class SSLChannelContextTests extends ESTestCase {
private Answer getReadAnswerForBytes(byte[] bytes) {
return invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
buffer.ensureCapacity(buffer.getIndex() + bytes.length);
ByteBuffer[] buffers = buffer.sliceBuffersFrom(buffer.getIndex());
InboundChannelBuffer appBuffer = (InboundChannelBuffer) invocationOnMock.getArguments()[1];
appBuffer.ensureCapacity(appBuffer.getIndex() + bytes.length);
ByteBuffer[] buffers = appBuffer.sliceBuffersFrom(appBuffer.getIndex());
assert buffers[0].remaining() > bytes.length;
buffers[0].put(bytes);
buffer.incrementIndex(bytes.length);
appBuffer.incrementIndex(bytes.length);
return bytes.length;
};
}

View File

@ -26,14 +26,16 @@ import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntFunction;
public class SSLDriverTests extends ESTestCase {
private final Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {});
private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
private final IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
private final InboundChannelBuffer networkReadBuffer = new InboundChannelBuffer(pageAllocator);
private final InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
private final AtomicInteger openPages = new AtomicInteger(0);
public void testPingPongAndClose() throws Exception {
SSLContext sslContext = getSSLContext();
@ -44,19 +46,36 @@ public class SSLDriverTests extends ESTestCase {
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
sendAppData(clientDriver, buffers);
serverDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, buffers2);
clientDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
assertFalse(clientDriver.needsNonApplicationWrite());
normalClose(clientDriver, serverDriver);
}
public void testDataStoredInOutboundBufferIsClosed() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
serverDriver.write(new FlushOperation(buffers, (v, e) -> {}));
expectThrows(SSLException.class, serverDriver::close);
assertEquals(0, openPages.get());
}
public void testRenegotiate() throws Exception {
SSLContext sslContext = getSSLContext();
@ -73,9 +92,10 @@ public class SSLDriverTests extends ESTestCase {
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
sendAppData(clientDriver, buffers);
serverDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
clientDriver.renegotiate();
assertTrue(clientDriver.isHandshaking());
@ -83,17 +103,20 @@ public class SSLDriverTests extends ESTestCase {
// This tests that the client driver can still receive data based on the prior handshake
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, buffers2);
clientDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
handshake(clientDriver, serverDriver, true);
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
sendAppData(clientDriver, buffers);
serverDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
sendAppData(serverDriver, buffers2);
clientDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
normalClose(clientDriver, serverDriver);
}
@ -108,18 +131,22 @@ public class SSLDriverTests extends ESTestCase {
ByteBuffer buffer = ByteBuffer.allocate(1 << 15);
for (int i = 0; i < (1 << 15); ++i) {
buffer.put((byte) i);
buffer.put((byte) (i % 127));
}
buffer.flip();
ByteBuffer[] buffers = {buffer};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[0].limit());
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[1].limit());
sendAppData(clientDriver, buffers);
serverDriver.read(networkReadBuffer, applicationBuffer);
ByteBuffer[] buffers1 = applicationBuffer.sliceBuffersFrom(0);
assertEquals((byte) (16383 % 127), buffers1[0].get(16383));
assertEquals((byte) (32767 % 127), buffers1[1].get(16383));
applicationBuffer.release(1 << 15);
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, buffers2);
clientDriver.read(networkReadBuffer, applicationBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
applicationBuffer.release(4);
assertFalse(clientDriver.needsNonApplicationWrite());
normalClose(clientDriver, serverDriver);
@ -193,16 +220,16 @@ public class SSLDriverTests extends ESTestCase {
serverDriver.initiateClose();
assertTrue(serverDriver.needsNonApplicationWrite());
assertFalse(serverDriver.isClosed());
sendNonApplicationWrites(serverDriver, clientDriver);
sendNonApplicationWrites(serverDriver);
// We are immediately fully closed due to SSLEngine inconsistency
assertTrue(serverDriver.isClosed());
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
clientDriver.read(clientBuffer);
sendNonApplicationWrites(clientDriver, serverDriver);
clientDriver.read(clientBuffer);
sendNonApplicationWrites(clientDriver, serverDriver);
serverDriver.read(serverBuffer);
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
assertEquals("Received close_notify during handshake", sslException.getMessage());
sendNonApplicationWrites(clientDriver);
assertTrue(clientDriver.isClosed());
serverDriver.read(networkReadBuffer, applicationBuffer);
}
public void testCloseDuringHandshakePreJDK11() throws Exception {
@ -226,26 +253,28 @@ public class SSLDriverTests extends ESTestCase {
serverDriver.initiateClose();
assertTrue(serverDriver.needsNonApplicationWrite());
assertFalse(serverDriver.isClosed());
sendNonApplicationWrites(serverDriver, clientDriver);
sendNonApplicationWrites(serverDriver);
// We are immediately fully closed due to SSLEngine inconsistency
assertTrue(serverDriver.isClosed());
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer));
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
assertEquals("Received close_notify during handshake", sslException.getMessage());
assertTrue(clientDriver.needsNonApplicationWrite());
sendNonApplicationWrites(clientDriver, serverDriver);
serverDriver.read(serverBuffer);
sendNonApplicationWrites(clientDriver);
assertTrue(clientDriver.isClosed());
serverDriver.read(networkReadBuffer, applicationBuffer);
}
private void failedCloseAlert(SSLDriver sendDriver, SSLDriver receiveDriver, List<String> messages) throws SSLException {
assertTrue(sendDriver.needsNonApplicationWrite());
assertFalse(sendDriver.isClosed());
sendNonApplicationWrites(sendDriver, receiveDriver);
sendNonApplicationWrites(sendDriver);
assertTrue(sendDriver.isClosed());
sendDriver.close();
SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(genericBuffer));
SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(networkReadBuffer, applicationBuffer));
assertTrue("Expected one of the following exception messages: " + messages + ". Found: " + sslException.getMessage(),
messages.stream().anyMatch(m -> sslException.getMessage().equals(m)));
if (receiveDriver.needsNonApplicationWrite() == false) {
@ -274,29 +303,30 @@ public class SSLDriverTests extends ESTestCase {
sendDriver.initiateClose();
assertFalse(sendDriver.readyForApplicationWrites());
assertTrue(sendDriver.needsNonApplicationWrite());
sendNonApplicationWrites(sendDriver, receiveDriver);
sendNonApplicationWrites(sendDriver);
assertFalse(sendDriver.isClosed());
receiveDriver.read(genericBuffer);
receiveDriver.read(networkReadBuffer, applicationBuffer);
assertFalse(receiveDriver.isClosed());
assertFalse(receiveDriver.readyForApplicationWrites());
assertTrue(receiveDriver.needsNonApplicationWrite());
sendNonApplicationWrites(receiveDriver, sendDriver);
sendNonApplicationWrites(receiveDriver);
assertTrue(receiveDriver.isClosed());
sendDriver.read(genericBuffer);
sendDriver.read(networkReadBuffer, applicationBuffer);
assertTrue(sendDriver.isClosed());
sendDriver.close();
receiveDriver.close();
assertEquals(0, openPages.get());
}
private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
private void sendNonApplicationWrites(SSLDriver sendDriver) throws SSLException {
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
sendData(outboundBuffer.buildNetworkFlushOperation());
} else {
sendDriver.nonApplicationWrite();
}
@ -342,8 +372,8 @@ public class SSLDriverTests extends ESTestCase {
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
if (outboundBuffer.hasEncryptedBytesToFlush()) {
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
receiveDriver.read(genericBuffer);
sendData(outboundBuffer.buildNetworkFlushOperation());
receiveDriver.read(networkReadBuffer, applicationBuffer);
} else {
sendDriver.nonApplicationWrite();
}
@ -353,37 +383,46 @@ public class SSLDriverTests extends ESTestCase {
}
}
private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
private void sendAppData(SSLDriver sendDriver, ByteBuffer[] message) throws IOException {
assertFalse(sendDriver.needsNonApplicationWrite());
int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {});
int bytesEncrypted = 0;
while (bytesToEncrypt > bytesEncrypted) {
bytesEncrypted += sendDriver.write(flushOperation);
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
while (flushOperation.isFullyFlushed() == false) {
sendDriver.write(flushOperation);
}
sendData(sendDriver.getOutboundBuffer().buildNetworkFlushOperation());
}
private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) {
ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
private void sendData(FlushOperation flushOperation) {
ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite();
int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer";
int bytesToCopy = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
networkReadBuffer.ensureCapacity(bytesToCopy + networkReadBuffer.getIndex());
ByteBuffer[] byteBuffers = networkReadBuffer.sliceBuffersFrom(0);
assert writeBuffers.length > 0 : "No write buffers";
for (ByteBuffer writeBuffer : writeBuffers) {
int written = writeBuffer.remaining();
int r = 0;
while (flushOperation.isFullyFlushed() == false) {
ByteBuffer readBuffer = byteBuffers[r];
ByteBuffer writeBuffer = flushOperation.getBuffersToWrite()[0];
int toWrite = Math.min(writeBuffer.remaining(), readBuffer.remaining());
writeBuffer.limit(writeBuffer.position() + toWrite);
readBuffer.put(writeBuffer);
flushOperation.incrementIndex(written);
flushOperation.incrementIndex(toWrite);
if (readBuffer.remaining() == 0) {
r++;
}
}
networkReadBuffer.incrementIndex(bytesToCopy);
assertTrue(flushOperation.isFullyFlushed());
flushOperation.getListener().accept(null, null);
}
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {
return new SSLDriver(engine, isClient);
return new SSLDriver(engine, (n) -> {
openPages.incrementAndGet();
return new Page(ByteBuffer.allocate(n), openPages::decrementAndGet);
}, isClient);
}
}