From 927013426a1fca0606c75d94c18cb07928f90243 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Mon, 6 May 2019 09:51:32 -0600 Subject: [PATCH] Read multiple TLS packets in one read call (#41820) This is related to #27260. Currently we have a single read buffer that is no larger than a single TLS packet. This prevents us from reading multiple TLS packets in a single socket read call. This commit modifies our TLS work to support reading similar to the plaintext case. The data will be copied to a (potentially) recycled TLS packet-sized buffer for interaction with the SSLEngine. --- .../nio/InboundChannelBuffer.java | 19 +- .../nio/SocketChannelContext.java | 55 +----- .../nio/utils/ByteBufferUtils.java | 63 +++++++ .../nio/InboundChannelBufferTests.java | 116 ++++++------ .../nio/SocketChannelContextTests.java | 72 +------ .../http/nio/NioHttpServerTransport.java | 14 +- .../transport/nio/NioTransport.java | 12 +- .../transport/nio/PageAllocator.java | 48 +++++ .../transport/nio/MockNioTransport.java | 14 +- .../transport/nio/SSLChannelContext.java | 15 +- .../security/transport/nio/SSLDriver.java | 108 ++++++----- .../nio/SecurityNioHttpServerTransport.java | 18 +- .../transport/nio/SecurityNioTransport.java | 18 +- .../transport/nio/SSLChannelContextTests.java | 34 ++-- .../transport/nio/SSLDriverTests.java | 175 +++++++++++------- 15 files changed, 407 insertions(+), 374 deletions(-) create mode 100644 libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java create mode 100644 plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java index 2dfd53d27e1..5c3b519e390 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java @@ -27,7 +27,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; +import java.util.function.IntFunction; /** * This is a channel byte buffer composed internally of 16kb pages. When an entire message has been read @@ -37,15 +37,14 @@ import java.util.function.Supplier; */ public final class InboundChannelBuffer implements AutoCloseable { - private static final int PAGE_SIZE = 1 << 14; + public static final int PAGE_SIZE = 1 << 14; private static final int PAGE_MASK = PAGE_SIZE - 1; private static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE); private static final ByteBuffer[] EMPTY_BYTE_BUFFER_ARRAY = new ByteBuffer[0]; private static final Page[] EMPTY_BYTE_PAGE_ARRAY = new Page[0]; - - private final ArrayDeque pages; - private final Supplier pageSupplier; + private final IntFunction pageAllocator; + private final ArrayDeque pages = new ArrayDeque<>(); private final AtomicBoolean isClosed = new AtomicBoolean(false); private long capacity = 0; @@ -53,14 +52,12 @@ public final class InboundChannelBuffer implements AutoCloseable { // The offset is an int as it is the offset of where the bytes begin in the first buffer private int offset = 0; - public InboundChannelBuffer(Supplier pageSupplier) { - this.pageSupplier = pageSupplier; - this.pages = new ArrayDeque<>(); - this.capacity = PAGE_SIZE * pages.size(); + public InboundChannelBuffer(IntFunction pageAllocator) { + this.pageAllocator = pageAllocator; } public static InboundChannelBuffer allocatingInstance() { - return new InboundChannelBuffer(() -> new Page(ByteBuffer.allocate(PAGE_SIZE), () -> {})); + return new InboundChannelBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {})); } @Override @@ -87,7 +84,7 @@ public final class InboundChannelBuffer implements AutoCloseable { int numPages = numPages(requiredCapacity + offset); int pagesToAdd = numPages - pages.size(); for (int i = 0; i < pagesToAdd; i++) { - Page page = pageSupplier.get(); + Page page = pageAllocator.apply(PAGE_SIZE); pages.addLast(page); } capacity += pagesToAdd * PAGE_SIZE; diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index a926bbc9710..22d85472126 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -20,6 +20,7 @@ package org.elasticsearch.nio; import org.elasticsearch.common.concurrent.CompletableContext; +import org.elasticsearch.nio.utils.ByteBufferUtils; import org.elasticsearch.nio.utils.ExceptionsHelper; import java.io.IOException; @@ -249,26 +250,6 @@ public abstract class SocketChannelContext extends ChannelContext // data that is copied to the buffer for a write, but not successfully flushed immediately, must be // copied again on the next call. - protected int readFromChannel(ByteBuffer buffer) throws IOException { - ByteBuffer ioBuffer = getSelector().getIoBuffer(); - ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit())); - int bytesRead; - try { - bytesRead = rawChannel.read(ioBuffer); - } catch (IOException e) { - closeNow = true; - throw e; - } - if (bytesRead < 0) { - closeNow = true; - return 0; - } else { - ioBuffer.flip(); - buffer.put(ioBuffer); - return bytesRead; - } - } - protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException { ByteBuffer ioBuffer = getSelector().getIoBuffer(); int bytesRead; @@ -288,7 +269,7 @@ public abstract class SocketChannelContext extends ChannelContext int j = 0; while (j < buffers.length && ioBuffer.remaining() > 0) { ByteBuffer buffer = buffers[j++]; - copyBytes(ioBuffer, buffer); + ByteBufferUtils.copyBytes(ioBuffer, buffer); } channelBuffer.incrementIndex(bytesRead); return bytesRead; @@ -299,24 +280,6 @@ public abstract class SocketChannelContext extends ChannelContext // copying. private final int WRITE_LIMIT = 1 << 16; - protected int flushToChannel(ByteBuffer buffer) throws IOException { - int initialPosition = buffer.position(); - ByteBuffer ioBuffer = getSelector().getIoBuffer(); - ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit())); - copyBytes(buffer, ioBuffer); - ioBuffer.flip(); - int bytesWritten; - try { - bytesWritten = rawChannel.write(ioBuffer); - } catch (IOException e) { - closeNow = true; - buffer.position(initialPosition); - throw e; - } - buffer.position(initialPosition + bytesWritten); - return bytesWritten; - } - protected int flushToChannel(FlushOperation flushOperation) throws IOException { ByteBuffer ioBuffer = getSelector().getIoBuffer(); @@ -325,12 +288,8 @@ public abstract class SocketChannelContext extends ChannelContext while (continueFlush) { ioBuffer.clear(); ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit())); - int j = 0; ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT); - while (j < buffers.length && ioBuffer.remaining() > 0) { - ByteBuffer buffer = buffers[j++]; - copyBytes(buffer, ioBuffer); - } + ByteBufferUtils.copyBytes(buffers, ioBuffer); ioBuffer.flip(); int bytesFlushed; try { @@ -345,12 +304,4 @@ public abstract class SocketChannelContext extends ChannelContext } return totalBytesFlushed; } - - private void copyBytes(ByteBuffer from, ByteBuffer to) { - int nBytesToCopy = Math.min(to.remaining(), from.remaining()); - int initialLimit = from.limit(); - from.limit(from.position() + nBytesToCopy); - to.put(from); - from.limit(initialLimit); - } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java b/libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java new file mode 100644 index 00000000000..0be9806bada --- /dev/null +++ b/libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio.utils; + +import java.nio.ByteBuffer; + +public final class ByteBufferUtils { + + private ByteBufferUtils() {} + + /** + * Copies bytes from the array of byte buffers into the destination buffer. The number of bytes copied is + * limited by the bytes available to copy and the space remaining in the destination byte buffer. + * + * @param source byte buffers to copy from + * @param destination byte buffer to copy to + * + * @return number of bytes copied + */ + public static long copyBytes(ByteBuffer[] source, ByteBuffer destination) { + long bytesCopied = 0; + for (int i = 0; i < source.length && destination.hasRemaining(); i++) { + ByteBuffer buffer = source[i]; + bytesCopied += copyBytes(buffer, destination); + } + return bytesCopied; + } + + /** + * Copies bytes from source byte buffer into the destination buffer. The number of bytes copied is + * limited by the bytes available to copy and the space remaining in the destination byte buffer. + * + * @param source byte buffer to copy from + * @param destination byte buffer to copy to + * + * @return number of bytes copied + */ + public static int copyBytes(ByteBuffer source, ByteBuffer destination) { + int nBytesToCopy = Math.min(destination.remaining(), source.remaining()); + int initialLimit = source.limit(); + source.limit(source.position() + nBytesToCopy); + destination.put(source); + source.limit(initialLimit); + return nBytesToCopy; + } +} diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java index f5580430953..49e4fbecec9 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java @@ -19,23 +19,25 @@ package org.elasticsearch.nio; -import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.test.ESTestCase; import java.nio.ByteBuffer; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; +import java.util.function.IntFunction; public class InboundChannelBufferTests extends ESTestCase { - private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES; - private final Supplier defaultPageSupplier = () -> - new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> { - }); + private IntFunction defaultPageAllocator; + + @Override + public void setUp() throws Exception { + super.setUp(); + defaultPageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {}); + } public void testNewBufferNoPages() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); assertEquals(0, channelBuffer.getCapacity()); assertEquals(0, channelBuffer.getRemaining()); @@ -43,107 +45,107 @@ public class InboundChannelBufferTests extends ESTestCase { } public void testExpandCapacity() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); assertEquals(0, channelBuffer.getCapacity()); assertEquals(0, channelBuffer.getRemaining()); - channelBuffer.ensureCapacity(PAGE_SIZE); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE); - assertEquals(PAGE_SIZE, channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining()); - channelBuffer.ensureCapacity(PAGE_SIZE + 1); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1); - assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE * 2, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getRemaining()); } public void testExpandCapacityMultiplePages() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); - channelBuffer.ensureCapacity(PAGE_SIZE); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE); - assertEquals(PAGE_SIZE, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity()); int multiple = randomInt(80); - channelBuffer.ensureCapacity(PAGE_SIZE + ((multiple * PAGE_SIZE) - randomInt(500))); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + ((multiple * InboundChannelBuffer.PAGE_SIZE) - randomInt(500))); - assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining()); } public void testExpandCapacityRespectsOffset() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); - channelBuffer.ensureCapacity(PAGE_SIZE); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE); - assertEquals(PAGE_SIZE, channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining()); int offset = randomInt(300); channelBuffer.release(offset); - assertEquals(PAGE_SIZE - offset, channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE - offset, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getRemaining()); - channelBuffer.ensureCapacity(PAGE_SIZE + 1); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1); - assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getCapacity()); - assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getRemaining()); } public void testIncrementIndex() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); - channelBuffer.ensureCapacity(PAGE_SIZE); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE); assertEquals(0, channelBuffer.getIndex()); - assertEquals(PAGE_SIZE, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining()); channelBuffer.incrementIndex(10); assertEquals(10, channelBuffer.getIndex()); - assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining()); } public void testIncrementIndexWithOffset() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); - channelBuffer.ensureCapacity(PAGE_SIZE); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE); assertEquals(0, channelBuffer.getIndex()); - assertEquals(PAGE_SIZE, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining()); channelBuffer.release(10); - assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining()); channelBuffer.incrementIndex(10); assertEquals(10, channelBuffer.getIndex()); - assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining()); channelBuffer.release(2); assertEquals(8, channelBuffer.getIndex()); - assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining()); + assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining()); } public void testReleaseClosesPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + IntFunction allocator = (n) -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true)); }; - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); - channelBuffer.ensureCapacity(PAGE_SIZE * 4); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4); - assertEquals(PAGE_SIZE * 4, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 4, channelBuffer.getCapacity()); assertEquals(4, queue.size()); for (AtomicBoolean closedRef : queue) { assertFalse(closedRef.get()); } - channelBuffer.release(2 * PAGE_SIZE); + channelBuffer.release(2 * InboundChannelBuffer.PAGE_SIZE); - assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity()); + assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity()); assertTrue(queue.poll().get()); assertTrue(queue.poll().get()); @@ -153,13 +155,13 @@ public class InboundChannelBufferTests extends ESTestCase { public void testClose() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + IntFunction allocator = (n) -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true)); }; - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); - channelBuffer.ensureCapacity(PAGE_SIZE * 4); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4); assertEquals(4, queue.size()); @@ -178,13 +180,13 @@ public class InboundChannelBufferTests extends ESTestCase { public void testCloseRetainedPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + IntFunction allocator = (n) -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true)); }; - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); - channelBuffer.ensureCapacity(PAGE_SIZE * 4); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator); + channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4); assertEquals(4, queue.size()); @@ -192,7 +194,7 @@ public class InboundChannelBufferTests extends ESTestCase { assertFalse(closedRef.get()); } - Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2); + Page[] pages = channelBuffer.sliceAndRetainPagesTo(InboundChannelBuffer.PAGE_SIZE * 2); pages[1].close(); @@ -220,10 +222,10 @@ public class InboundChannelBufferTests extends ESTestCase { } public void testAccessByteBuffers() { - InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier); + InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator); int pages = randomInt(50) + 5; - channelBuffer.ensureCapacity(pages * PAGE_SIZE); + channelBuffer.ensureCapacity(pages * InboundChannelBuffer.PAGE_SIZE); long capacity = channelBuffer.getCapacity(); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index baf7abac79d..0040f70df85 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -34,8 +34,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.IntFunction; import java.util.function.Predicate; -import java.util.function.Supplier; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -285,8 +285,8 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + IntFunction pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator); buffer.ensureCapacity(1); TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); context.closeFromSelector(); @@ -294,29 +294,6 @@ public class SocketChannelContextTests extends ESTestCase { } } - public void testReadToBufferLimitsToPassedBuffer() throws IOException { - ByteBuffer buffer = ByteBuffer.allocate(10); - when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer()); - - int bytesRead = context.readFromChannel(buffer); - assertEquals(bytesRead, 10); - assertEquals(0, buffer.remaining()); - } - - public void testReadToBufferHandlesIOException() throws IOException { - when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException()); - - expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10))); - assertTrue(context.closeNow()); - } - - public void testReadToBufferHandlesEOF() throws IOException { - when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1); - - context.readFromChannel(ByteBuffer.allocate(10)); - assertTrue(context.closeNow()); - } - public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException { when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer()); @@ -344,33 +321,6 @@ public class SocketChannelContextTests extends ESTestCase { assertEquals(0, channelBuffer.getIndex()); } - public void testFlushBufferHandlesPartialFlush() throws IOException { - int bytesToConsume = 3; - when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume)); - - ByteBuffer buffer = ByteBuffer.allocate(10); - context.flushToChannel(buffer); - assertEquals(10 - bytesToConsume, buffer.remaining()); - } - - public void testFlushBufferHandlesFullFlush() throws IOException { - int bytesToConsume = 10; - when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume)); - - ByteBuffer buffer = ByteBuffer.allocate(10); - context.flushToChannel(buffer); - assertEquals(0, buffer.remaining()); - } - - public void testFlushBufferHandlesIOException() throws IOException { - when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException()); - - ByteBuffer buffer = ByteBuffer.allocate(10); - expectThrows(IOException.class, () -> context.flushToChannel(buffer)); - assertTrue(context.closeNow()); - assertEquals(10, buffer.remaining()); - } - public void testFlushBuffersHandlesZeroFlush() throws IOException { when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0)); @@ -456,22 +406,14 @@ public class SocketChannelContextTests extends ESTestCase { @Override public int read() throws IOException { - if (randomBoolean()) { - InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - return readFromChannel(channelBuffer); - } else { - return readFromChannel(ByteBuffer.allocate(10)); - } + InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); + return readFromChannel(channelBuffer); } @Override public void flushChannel() throws IOException { - if (randomBoolean()) { - ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; - flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {})); - } else { - flushToChannel(ByteBuffer.allocate(10)); - } + ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)}; + flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {})); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 2730cb6d3a9..fa0f3e9572c 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -25,7 +25,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.common.unit.ByteSizeValue; @@ -43,16 +42,15 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.nio.NioGroupFactory; +import org.elasticsearch.transport.nio.PageAllocator; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.Arrays; @@ -80,8 +78,8 @@ import static org.elasticsearch.http.nio.cors.NioCorsHandler.ANY_ORIGIN; public class NioHttpServerTransport extends AbstractHttpServerTransport { private static final Logger logger = LogManager.getLogger(NioHttpServerTransport.class); - protected final PageCacheRecycler pageCacheRecycler; protected final NioCorsConfig corsConfig; + protected final PageAllocator pageAllocator; private final NioGroupFactory nioGroupFactory; protected final boolean tcpNoDelay; @@ -97,7 +95,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { PageCacheRecycler pageCacheRecycler, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, Dispatcher dispatcher, NioGroupFactory nioGroupFactory) { super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); - this.pageCacheRecycler = pageCacheRecycler; + this.pageAllocator = new PageAllocator(pageCacheRecycler); this.nioGroupFactory = nioGroupFactory; ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); @@ -206,15 +204,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - java.util.function.Supplier pageSupplier = () -> { - Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); - }; HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis); Consumer exceptionHandler = (e) -> onException(httpChannel, e); SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline, - new InboundChannelBuffer(pageSupplier)); + new InboundChannelBuffer(pageAllocator)); httpChannel.setContext(context); return httpChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 17dc6c41baa..a39098a3d59 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -26,7 +26,6 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -36,20 +35,17 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; @@ -57,6 +53,7 @@ public class NioTransport extends TcpTransport { private static final Logger logger = LogManager.getLogger(NioTransport.class); + protected final PageAllocator pageAllocator; private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private final NioGroupFactory groupFactory; private volatile NioGroup nioGroup; @@ -66,6 +63,7 @@ public class NioTransport extends TcpTransport { PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService, NioGroupFactory groupFactory) { super(settings, version, threadPool, pageCacheRecycler, circuitBreakerService, namedWriteableRegistry, networkService); + this.pageAllocator = new PageAllocator(pageCacheRecycler); this.groupFactory = groupFactory; } @@ -158,14 +156,10 @@ public class NioTransport extends TcpTransport { @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { - Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); - }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); Consumer exceptionHandler = (e) -> onException(nioChannel, e); BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, - new InboundChannelBuffer(pageSupplier)); + new InboundChannelBuffer(pageAllocator)); nioChannel.setContext(context); return nioChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java new file mode 100644 index 00000000000..bf9f3ffc891 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java @@ -0,0 +1,48 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.nio.Page; + +import java.nio.ByteBuffer; +import java.util.function.IntFunction; + +public class PageAllocator implements IntFunction { + + private static final int RECYCLE_LOWER_THRESHOLD = PageCacheRecycler.BYTE_PAGE_SIZE / 2; + + private final PageCacheRecycler recycler; + + public PageAllocator(PageCacheRecycler recycler) { + this.recycler = recycler; + } + + @Override + public Page apply(int length) { + if (length >= RECYCLE_LOWER_THRESHOLD && length <= PageCacheRecycler.BYTE_PAGE_SIZE){ + Recycler.V bytePage = recycler.bytePage(false); + return new Page(ByteBuffer.wrap(bytePage.v(), 0, length), bytePage::close); + } else { + return new Page(ByteBuffer.allocate(length), () -> {}); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index abb92979f8d..39316ca9192 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -37,8 +37,8 @@ import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesWriteHandler; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSelectorGroup; import org.elasticsearch.nio.NioSelector; +import org.elasticsearch.nio.NioSelectorGroup; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.Page; @@ -61,7 +61,7 @@ import java.util.HashSet; import java.util.Set; import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.function.IntFunction; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory; @@ -192,9 +192,13 @@ public class MockNioTransport extends TcpTransport { @Override public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { - Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + IntFunction pageSupplier = (length) -> { + if (length > PageCacheRecycler.BYTE_PAGE_SIZE) { + return new Page(ByteBuffer.allocate(length), () -> {}); + } else { + Recycler.V bytes = pageCacheRecycler.bytePage(false); + return new Page(ByteBuffer.wrap(bytes.v(), 0, length), bytes::close); + } }; MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 9372cb1ec54..de1259765b9 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -36,19 +36,22 @@ public final class SSLChannelContext extends SocketChannelContext { private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {}; private final SSLDriver sslDriver; + private final InboundChannelBuffer networkReadBuffer; private final LinkedList encryptedFlushes = new LinkedList<>(); private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) { - this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL); + ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) { + this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(), + applicationBuffer, ALWAYS_ALLOW_CHANNEL); } SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer, + ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer, Predicate allowChannelPredicate) { super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); this.sslDriver = sslDriver; + this.networkReadBuffer = networkReadBuffer; } @Override @@ -157,12 +160,12 @@ public final class SSLChannelContext extends SocketChannelContext { if (closeNow()) { return bytesRead; } - bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer()); + bytesRead = readFromChannel(networkReadBuffer); if (bytesRead == 0) { return bytesRead; } - sslDriver.read(channelBuffer); + sslDriver.read(networkReadBuffer, channelBuffer); handleReadBytes(); // It is possible that a read call produced non-application bytes to flush @@ -201,7 +204,7 @@ public final class SSLChannelContext extends SocketChannelContext { getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException()); } encryptedFlushes.clear(); - IOUtils.close(super::closeFromSelector, sslDriver::close); + IOUtils.close(super::closeFromSelector, networkReadBuffer::close, sslDriver::close); } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java index bc112dd3a60..e54bc9fa16e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.Page; +import org.elasticsearch.nio.utils.ByteBufferUtils; import org.elasticsearch.nio.utils.ExceptionsHelper; import javax.net.ssl.SSLEngine; @@ -16,6 +17,7 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.function.IntFunction; /** * SSLDriver is a class that wraps the {@link SSLEngine} and attempts to simplify the API. The basic usage is @@ -27,9 +29,9 @@ import java.util.ArrayList; * application to be written to the wire. * * Handling reads from a channel with this class is very simple. When data has been read, call - * {@link #read(InboundChannelBuffer)}. If the data is application data, it will be decrypted and placed into - * the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close - * or handshake process. + * {@link #read(InboundChannelBuffer, InboundChannelBuffer)}. If the data is application data, it will be + * decrypted and placed into the application buffer passed as an argument. Otherwise, it will be consumed + * internally and advance the SSL/TLS close or handshake process. * * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be * called to determine if this driver needs to produce more data to advance the handshake or close process. @@ -54,21 +56,22 @@ public class SSLDriver implements AutoCloseable { private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {}); private final SSLEngine engine; - // TODO: When the bytes are actually recycled, we need to test that they are released on driver close - private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + private final IntFunction pageAllocator; + private final SSLOutboundBuffer outboundBuffer; + private Page networkReadPage; private final boolean isClientMode; // This should only be accessed by the network thread associated with this channel, so nothing needs to // be volatile. private Mode currentMode = new HandshakeMode(); - private ByteBuffer networkReadBuffer; private int packetSize; - public SSLDriver(SSLEngine engine, boolean isClientMode) { + public SSLDriver(SSLEngine engine, IntFunction pageAllocator, boolean isClientMode) { this.engine = engine; + this.pageAllocator = pageAllocator; + this.outboundBuffer = new SSLOutboundBuffer(pageAllocator); this.isClientMode = isClientMode; SSLSession session = engine.getSession(); packetSize = session.getPacketBufferSize(); - this.networkReadBuffer = ByteBuffer.allocate(packetSize); } public void init() throws SSLException { @@ -106,22 +109,25 @@ public class SSLDriver implements AutoCloseable { return currentMode.isHandshake(); } - public ByteBuffer getNetworkReadBuffer() { - return networkReadBuffer; - } - public SSLOutboundBuffer getOutboundBuffer() { return outboundBuffer; } - public void read(InboundChannelBuffer buffer) throws SSLException { - Mode modePriorToRead; - do { - modePriorToRead = currentMode; - currentMode.read(buffer); - // If we switched modes we want to read again as there might be unhandled bytes that need to be - // handled by the new mode. - } while (modePriorToRead != currentMode); + public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException { + networkReadPage = pageAllocator.apply(packetSize); + try { + Mode modePriorToRead; + do { + modePriorToRead = currentMode; + currentMode.read(encryptedBuffer, applicationBuffer); + // It is possible that we received multiple SSL packets from the network since the last read. + // If one of those packets causes us to change modes (such as finished handshaking), we need + // to call read in the new mode to handle the remaining packets. + } while (modePriorToRead != currentMode); + } finally { + networkReadPage.close(); + networkReadPage = null; + } } public boolean readyForApplicationWrites() { @@ -171,27 +177,34 @@ public class SSLDriver implements AutoCloseable { ExceptionsHelper.rethrowAndSuppress(closingExceptions); } - private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException { + private SSLEngineResult unwrap(InboundChannelBuffer networkBuffer, InboundChannelBuffer applicationBuffer) throws SSLException { while (true) { - SSLEngineResult result = engine.unwrap(networkReadBuffer, buffer.sliceBuffersFrom(buffer.getIndex())); - buffer.incrementIndex(result.bytesProduced()); + ensureApplicationBufferSize(applicationBuffer); + ByteBuffer networkReadBuffer = networkReadPage.byteBuffer(); + networkReadBuffer.clear(); + ByteBufferUtils.copyBytes(networkBuffer.sliceBuffersTo(Math.min(networkBuffer.getIndex(), packetSize)), networkReadBuffer); + networkReadBuffer.flip(); + SSLEngineResult result = engine.unwrap(networkReadBuffer, applicationBuffer.sliceBuffersFrom(applicationBuffer.getIndex())); + networkBuffer.release(result.bytesConsumed()); + applicationBuffer.incrementIndex(result.bytesProduced()); switch (result.getStatus()) { case OK: - networkReadBuffer.compact(); return result; case BUFFER_UNDERFLOW: // There is not enough space in the network buffer for an entire SSL packet. Compact the // current data and expand the buffer if necessary. - int currentCapacity = networkReadBuffer.capacity(); - ensureNetworkReadBufferSize(); - if (currentCapacity == networkReadBuffer.capacity()) { - networkReadBuffer.compact(); + packetSize = engine.getSession().getPacketBufferSize(); + if (networkReadPage.byteBuffer().capacity() < packetSize) { + networkReadPage.close(); + networkReadPage = pageAllocator.apply(packetSize); + } else { + return result; } - return result; + break; case BUFFER_OVERFLOW: // There is not enough space in the application buffer for the decrypted message. Expand // the application buffer to ensure that it has enough space. - ensureApplicationBufferSize(buffer); + ensureApplicationBufferSize(applicationBuffer); break; case CLOSED: assert engine.isInboundDone() : "We received close_notify so read should be done"; @@ -254,15 +267,6 @@ public class SSLDriver implements AutoCloseable { } } - private void ensureNetworkReadBufferSize() { - packetSize = engine.getSession().getPacketBufferSize(); - if (networkReadBuffer.capacity() < packetSize) { - ByteBuffer newBuffer = ByteBuffer.allocate(packetSize); - networkReadBuffer.flip(); - newBuffer.put(networkReadBuffer); - } - } - // There are three potential modes for the driver to be in - HANDSHAKE, APPLICATION, or CLOSE. HANDSHAKE // is the initial mode. During this mode data that is read and written will be related to the TLS // handshake process. Application related data cannot be encrypted until the handshake is complete. From @@ -282,7 +286,7 @@ public class SSLDriver implements AutoCloseable { private interface Mode { - void read(InboundChannelBuffer buffer) throws SSLException; + void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException; int write(FlushOperation applicationBytes) throws SSLException; @@ -342,13 +346,11 @@ public class SSLDriver implements AutoCloseable { } @Override - public void read(InboundChannelBuffer buffer) throws SSLException { - ensureApplicationBufferSize(buffer); + public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException { boolean continueUnwrap = true; - while (continueUnwrap && networkReadBuffer.position() > 0) { - networkReadBuffer.flip(); + while (continueUnwrap && encryptedBuffer.getIndex() > 0) { try { - SSLEngineResult result = unwrap(buffer); + SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer); handshakeStatus = result.getHandshakeStatus(); handshake(); // If we are done handshaking we should exit the handshake read @@ -430,12 +432,10 @@ public class SSLDriver implements AutoCloseable { private class ApplicationMode implements Mode { @Override - public void read(InboundChannelBuffer buffer) throws SSLException { - ensureApplicationBufferSize(buffer); + public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException { boolean continueUnwrap = true; - while (continueUnwrap && networkReadBuffer.position() > 0) { - networkReadBuffer.flip(); - SSLEngineResult result = unwrap(buffer); + while (continueUnwrap && encryptedBuffer.getIndex() > 0) { + SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer); boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED && maybeRenegotiation(result.getHandshakeStatus()); continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false; @@ -515,7 +515,7 @@ public class SSLDriver implements AutoCloseable { } @Override - public void read(InboundChannelBuffer buffer) throws SSLException { + public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException { if (needToReceiveClose == false) { // There is an issue where receiving handshake messages after initiating the close process // can place the SSLEngine back into handshaking mode. In order to handle this, if we @@ -524,11 +524,9 @@ public class SSLDriver implements AutoCloseable { return; } - ensureApplicationBufferSize(buffer); boolean continueUnwrap = true; - while (continueUnwrap && networkReadBuffer.position() > 0) { - networkReadBuffer.flip(); - SSLEngineResult result = unwrap(buffer); + while (continueUnwrap && encryptedBuffer.getIndex() > 0) { + SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer); continueUnwrap = result.bytesProduced() > 0 || result.bytesConsumed() > 0; } if (engine.isInboundDone()) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index b65f29eb951..ddf465f81d9 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.transport.nio; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.PageCacheRecycler; @@ -22,7 +21,6 @@ import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -35,11 +33,9 @@ import org.elasticsearch.xpack.security.transport.filter.IPFilter; import javax.net.ssl.SSLEngine; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.function.Consumer; -import java.util.function.Supplier; import static org.elasticsearch.xpack.core.XPackSettings.HTTP_SSL_ENABLED; @@ -93,13 +89,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - Supplier pageSupplier = () -> { - Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); - }; HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator); Consumer exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e); SocketChannelContext context; @@ -113,10 +105,12 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { } else { sslEngine = sslService.createSSLEngine(sslConfiguration, null, -1); } - SSLDriver sslDriver = new SSLDriver(sslEngine, false); - context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, buffer, nioIpFilter); + SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false); + InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); + context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer, + applicationBuffer, nioIpFilter); } else { - context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, buffer, nioIpFilter); + context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter); } httpChannel.setContext(context); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index d3f92a2575f..cf32809333e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -12,7 +12,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -21,7 +20,6 @@ import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -45,14 +43,12 @@ import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import java.io.IOException; import java.net.InetSocketAddress; -import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.Collections; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import static org.elasticsearch.xpack.core.security.SecurityField.setting; @@ -156,20 +152,18 @@ public class SecurityNioTransport extends NioTransport { @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { - Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); - }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator); Consumer exceptionHandler = (e) -> onException(nioChannel, e); SocketChannelContext context; if (sslEnabled) { - SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient); - context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter); + SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient); + InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); + context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer, + applicationBuffer, ipFilter); } else { - context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter); + context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter); } nioChannel.setContext(context); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index dcccb23f1f6..6a380a8fab2 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -52,7 +52,6 @@ public class SSLChannelContextTests extends ESTestCase { private BiConsumer listener; private Consumer exceptionHandler; private SSLDriver sslDriver; - private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14); private int messageLength; @Before @@ -76,7 +75,6 @@ public class SSLChannelContextTests extends ESTestCase { when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); - when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { @@ -88,8 +86,12 @@ public class SSLChannelContextTests extends ESTestCase { public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> { + ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0]; + buffer.put(bytes); + return bytes.length; + }); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer)); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0); @@ -103,8 +105,12 @@ public class SSLChannelContextTests extends ESTestCase { public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); - when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> { + ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0]; + buffer.put(bytes); + return bytes.length; + }); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer)); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0); @@ -118,8 +124,12 @@ public class SSLChannelContextTests extends ESTestCase { public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> { + ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0]; + buffer.put(bytes); + return bytes.length; + }); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer)); when(readConsumer.apply(channelBuffer)).thenReturn(0); @@ -424,12 +434,12 @@ public class SSLChannelContextTests extends ESTestCase { private Answer getReadAnswerForBytes(byte[] bytes) { return invocationOnMock -> { - InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0]; - buffer.ensureCapacity(buffer.getIndex() + bytes.length); - ByteBuffer[] buffers = buffer.sliceBuffersFrom(buffer.getIndex()); + InboundChannelBuffer appBuffer = (InboundChannelBuffer) invocationOnMock.getArguments()[1]; + appBuffer.ensureCapacity(appBuffer.getIndex() + bytes.length); + ByteBuffer[] buffers = appBuffer.sliceBuffersFrom(appBuffer.getIndex()); assert buffers[0].remaining() > bytes.length; buffers[0].put(bytes); - buffer.incrementIndex(bytes.length); + appBuffer.incrementIndex(bytes.length); return bytes.length; }; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java index 376c4e1e99a..fba6db47c1b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java @@ -26,14 +26,16 @@ import java.security.SecureRandom; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.function.Supplier; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntFunction; public class SSLDriverTests extends ESTestCase { - private final Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {}); - private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier); - private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier); - private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier); + private final IntFunction pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {}); + + private final InboundChannelBuffer networkReadBuffer = new InboundChannelBuffer(pageAllocator); + private final InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); + private final AtomicInteger openPages = new AtomicInteger(0); public void testPingPongAndClose() throws Exception { SSLContext sslContext = getSSLContext(); @@ -44,19 +46,36 @@ public class SSLDriverTests extends ESTestCase { handshake(clientDriver, serverDriver); ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))}; - sendAppData(clientDriver, serverDriver, buffers); - serverDriver.read(serverBuffer); - assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]); + sendAppData(clientDriver, buffers); + serverDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))}; - sendAppData(serverDriver, clientDriver, buffers2); - clientDriver.read(clientBuffer); - assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]); + sendAppData(serverDriver, buffers2); + clientDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); assertFalse(clientDriver.needsNonApplicationWrite()); normalClose(clientDriver, serverDriver); } + public void testDataStoredInOutboundBufferIsClosed() throws Exception { + SSLContext sslContext = getSSLContext(); + + SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true); + SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false); + + handshake(clientDriver, serverDriver); + + ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))}; + serverDriver.write(new FlushOperation(buffers, (v, e) -> {})); + + expectThrows(SSLException.class, serverDriver::close); + assertEquals(0, openPages.get()); + } + public void testRenegotiate() throws Exception { SSLContext sslContext = getSSLContext(); @@ -73,9 +92,10 @@ public class SSLDriverTests extends ESTestCase { handshake(clientDriver, serverDriver); ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))}; - sendAppData(clientDriver, serverDriver, buffers); - serverDriver.read(serverBuffer); - assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]); + sendAppData(clientDriver, buffers); + serverDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); clientDriver.renegotiate(); assertTrue(clientDriver.isHandshaking()); @@ -83,17 +103,20 @@ public class SSLDriverTests extends ESTestCase { // This tests that the client driver can still receive data based on the prior handshake ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))}; - sendAppData(serverDriver, clientDriver, buffers2); - clientDriver.read(clientBuffer); - assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]); + sendAppData(serverDriver, buffers2); + clientDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); handshake(clientDriver, serverDriver, true); - sendAppData(clientDriver, serverDriver, buffers); - serverDriver.read(serverBuffer); - assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]); - sendAppData(serverDriver, clientDriver, buffers2); - clientDriver.read(clientBuffer); - assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]); + sendAppData(clientDriver, buffers); + serverDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); + sendAppData(serverDriver, buffers2); + clientDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); normalClose(clientDriver, serverDriver); } @@ -108,18 +131,22 @@ public class SSLDriverTests extends ESTestCase { ByteBuffer buffer = ByteBuffer.allocate(1 << 15); for (int i = 0; i < (1 << 15); ++i) { - buffer.put((byte) i); + buffer.put((byte) (i % 127)); } + buffer.flip(); ByteBuffer[] buffers = {buffer}; - sendAppData(clientDriver, serverDriver, buffers); - serverDriver.read(serverBuffer); - assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[0].limit()); - assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[1].limit()); + sendAppData(clientDriver, buffers); + serverDriver.read(networkReadBuffer, applicationBuffer); + ByteBuffer[] buffers1 = applicationBuffer.sliceBuffersFrom(0); + assertEquals((byte) (16383 % 127), buffers1[0].get(16383)); + assertEquals((byte) (32767 % 127), buffers1[1].get(16383)); + applicationBuffer.release(1 << 15); ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))}; - sendAppData(serverDriver, clientDriver, buffers2); - clientDriver.read(clientBuffer); - assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]); + sendAppData(serverDriver, buffers2); + clientDriver.read(networkReadBuffer, applicationBuffer); + assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]); + applicationBuffer.release(4); assertFalse(clientDriver.needsNonApplicationWrite()); normalClose(clientDriver, serverDriver); @@ -193,16 +220,16 @@ public class SSLDriverTests extends ESTestCase { serverDriver.initiateClose(); assertTrue(serverDriver.needsNonApplicationWrite()); assertFalse(serverDriver.isClosed()); - sendNonApplicationWrites(serverDriver, clientDriver); + sendNonApplicationWrites(serverDriver); // We are immediately fully closed due to SSLEngine inconsistency assertTrue(serverDriver.isClosed()); - // This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP - clientDriver.read(clientBuffer); - sendNonApplicationWrites(clientDriver, serverDriver); - clientDriver.read(clientBuffer); - sendNonApplicationWrites(clientDriver, serverDriver); - serverDriver.read(serverBuffer); + + SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer)); + assertEquals("Received close_notify during handshake", sslException.getMessage()); + sendNonApplicationWrites(clientDriver); assertTrue(clientDriver.isClosed()); + + serverDriver.read(networkReadBuffer, applicationBuffer); } public void testCloseDuringHandshakePreJDK11() throws Exception { @@ -226,26 +253,28 @@ public class SSLDriverTests extends ESTestCase { serverDriver.initiateClose(); assertTrue(serverDriver.needsNonApplicationWrite()); assertFalse(serverDriver.isClosed()); - sendNonApplicationWrites(serverDriver, clientDriver); + sendNonApplicationWrites(serverDriver); // We are immediately fully closed due to SSLEngine inconsistency assertTrue(serverDriver.isClosed()); - SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer)); + // This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP + + SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer)); assertEquals("Received close_notify during handshake", sslException.getMessage()); - assertTrue(clientDriver.needsNonApplicationWrite()); - sendNonApplicationWrites(clientDriver, serverDriver); - serverDriver.read(serverBuffer); + sendNonApplicationWrites(clientDriver); assertTrue(clientDriver.isClosed()); + + serverDriver.read(networkReadBuffer, applicationBuffer); } private void failedCloseAlert(SSLDriver sendDriver, SSLDriver receiveDriver, List messages) throws SSLException { assertTrue(sendDriver.needsNonApplicationWrite()); assertFalse(sendDriver.isClosed()); - sendNonApplicationWrites(sendDriver, receiveDriver); + sendNonApplicationWrites(sendDriver); assertTrue(sendDriver.isClosed()); sendDriver.close(); - SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(genericBuffer)); + SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(networkReadBuffer, applicationBuffer)); assertTrue("Expected one of the following exception messages: " + messages + ". Found: " + sslException.getMessage(), messages.stream().anyMatch(m -> sslException.getMessage().equals(m))); if (receiveDriver.needsNonApplicationWrite() == false) { @@ -274,29 +303,30 @@ public class SSLDriverTests extends ESTestCase { sendDriver.initiateClose(); assertFalse(sendDriver.readyForApplicationWrites()); assertTrue(sendDriver.needsNonApplicationWrite()); - sendNonApplicationWrites(sendDriver, receiveDriver); + sendNonApplicationWrites(sendDriver); assertFalse(sendDriver.isClosed()); - receiveDriver.read(genericBuffer); + receiveDriver.read(networkReadBuffer, applicationBuffer); assertFalse(receiveDriver.isClosed()); assertFalse(receiveDriver.readyForApplicationWrites()); assertTrue(receiveDriver.needsNonApplicationWrite()); - sendNonApplicationWrites(receiveDriver, sendDriver); + sendNonApplicationWrites(receiveDriver); assertTrue(receiveDriver.isClosed()); - sendDriver.read(genericBuffer); + sendDriver.read(networkReadBuffer, applicationBuffer); assertTrue(sendDriver.isClosed()); sendDriver.close(); receiveDriver.close(); + assertEquals(0, openPages.get()); } - private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException { + private void sendNonApplicationWrites(SSLDriver sendDriver) throws SSLException { SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); + sendData(outboundBuffer.buildNetworkFlushOperation()); } else { sendDriver.nonApplicationWrite(); } @@ -342,8 +372,8 @@ public class SSLDriverTests extends ESTestCase { while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); - receiveDriver.read(genericBuffer); + sendData(outboundBuffer.buildNetworkFlushOperation()); + receiveDriver.read(networkReadBuffer, applicationBuffer); } else { sendDriver.nonApplicationWrite(); } @@ -353,37 +383,46 @@ public class SSLDriverTests extends ESTestCase { } } - private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException { + private void sendAppData(SSLDriver sendDriver, ByteBuffer[] message) throws IOException { assertFalse(sendDriver.needsNonApplicationWrite()); - int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum(); - SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {}); - int bytesEncrypted = 0; - while (bytesToEncrypt > bytesEncrypted) { - bytesEncrypted += sendDriver.write(flushOperation); - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); + while (flushOperation.isFullyFlushed() == false) { + sendDriver.write(flushOperation); } + sendData(sendDriver.getOutboundBuffer().buildNetworkFlushOperation()); } - private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) { - ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer(); + private void sendData(FlushOperation flushOperation) { ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite(); - int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum(); - assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer"; + int bytesToCopy = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum(); + networkReadBuffer.ensureCapacity(bytesToCopy + networkReadBuffer.getIndex()); + ByteBuffer[] byteBuffers = networkReadBuffer.sliceBuffersFrom(0); assert writeBuffers.length > 0 : "No write buffers"; - for (ByteBuffer writeBuffer : writeBuffers) { - int written = writeBuffer.remaining(); + int r = 0; + while (flushOperation.isFullyFlushed() == false) { + ByteBuffer readBuffer = byteBuffers[r]; + ByteBuffer writeBuffer = flushOperation.getBuffersToWrite()[0]; + int toWrite = Math.min(writeBuffer.remaining(), readBuffer.remaining()); + writeBuffer.limit(writeBuffer.position() + toWrite); readBuffer.put(writeBuffer); - flushOperation.incrementIndex(written); + flushOperation.incrementIndex(toWrite); + if (readBuffer.remaining() == 0) { + r++; + } } + networkReadBuffer.incrementIndex(bytesToCopy); assertTrue(flushOperation.isFullyFlushed()); + flushOperation.getListener().accept(null, null); } private SSLDriver getDriver(SSLEngine engine, boolean isClient) { - return new SSLDriver(engine, isClient); + return new SSLDriver(engine, (n) -> { + openPages.incrementAndGet(); + return new Page(ByteBuffer.allocate(n), openPages::decrementAndGet); + }, isClient); } }