Read multiple TLS packets in one read call (#41820)
This is related to #27260. Currently we have a single read buffer that is no larger than a single TLS packet. This prevents us from reading multiple TLS packets in a single socket read call. This commit modifies our TLS work to support reading similar to the plaintext case. The data will be copied to a (potentially) recycled TLS packet-sized buffer for interaction with the SSLEngine.
This commit is contained in:
parent
228d23de6d
commit
927013426a
|
@ -27,7 +27,7 @@ import java.util.ArrayList;
|
|||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
/**
|
||||
* This is a channel byte buffer composed internally of 16kb pages. When an entire message has been read
|
||||
|
@ -37,15 +37,14 @@ import java.util.function.Supplier;
|
|||
*/
|
||||
public final class InboundChannelBuffer implements AutoCloseable {
|
||||
|
||||
private static final int PAGE_SIZE = 1 << 14;
|
||||
public static final int PAGE_SIZE = 1 << 14;
|
||||
private static final int PAGE_MASK = PAGE_SIZE - 1;
|
||||
private static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE);
|
||||
private static final ByteBuffer[] EMPTY_BYTE_BUFFER_ARRAY = new ByteBuffer[0];
|
||||
private static final Page[] EMPTY_BYTE_PAGE_ARRAY = new Page[0];
|
||||
|
||||
|
||||
private final ArrayDeque<Page> pages;
|
||||
private final Supplier<Page> pageSupplier;
|
||||
private final IntFunction<Page> pageAllocator;
|
||||
private final ArrayDeque<Page> pages = new ArrayDeque<>();
|
||||
private final AtomicBoolean isClosed = new AtomicBoolean(false);
|
||||
|
||||
private long capacity = 0;
|
||||
|
@ -53,14 +52,12 @@ public final class InboundChannelBuffer implements AutoCloseable {
|
|||
// The offset is an int as it is the offset of where the bytes begin in the first buffer
|
||||
private int offset = 0;
|
||||
|
||||
public InboundChannelBuffer(Supplier<Page> pageSupplier) {
|
||||
this.pageSupplier = pageSupplier;
|
||||
this.pages = new ArrayDeque<>();
|
||||
this.capacity = PAGE_SIZE * pages.size();
|
||||
public InboundChannelBuffer(IntFunction<Page> pageAllocator) {
|
||||
this.pageAllocator = pageAllocator;
|
||||
}
|
||||
|
||||
public static InboundChannelBuffer allocatingInstance() {
|
||||
return new InboundChannelBuffer(() -> new Page(ByteBuffer.allocate(PAGE_SIZE), () -> {}));
|
||||
return new InboundChannelBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -87,7 +84,7 @@ public final class InboundChannelBuffer implements AutoCloseable {
|
|||
int numPages = numPages(requiredCapacity + offset);
|
||||
int pagesToAdd = numPages - pages.size();
|
||||
for (int i = 0; i < pagesToAdd; i++) {
|
||||
Page page = pageSupplier.get();
|
||||
Page page = pageAllocator.apply(PAGE_SIZE);
|
||||
pages.addLast(page);
|
||||
}
|
||||
capacity += pagesToAdd * PAGE_SIZE;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.nio;
|
||||
|
||||
import org.elasticsearch.common.concurrent.CompletableContext;
|
||||
import org.elasticsearch.nio.utils.ByteBufferUtils;
|
||||
import org.elasticsearch.nio.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -249,26 +250,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
// data that is copied to the buffer for a write, but not successfully flushed immediately, must be
|
||||
// copied again on the next call.
|
||||
|
||||
protected int readFromChannel(ByteBuffer buffer) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit()));
|
||||
int bytesRead;
|
||||
try {
|
||||
bytesRead = rawChannel.read(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
throw e;
|
||||
}
|
||||
if (bytesRead < 0) {
|
||||
closeNow = true;
|
||||
return 0;
|
||||
} else {
|
||||
ioBuffer.flip();
|
||||
buffer.put(ioBuffer);
|
||||
return bytesRead;
|
||||
}
|
||||
}
|
||||
|
||||
protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
int bytesRead;
|
||||
|
@ -288,7 +269,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
int j = 0;
|
||||
while (j < buffers.length && ioBuffer.remaining() > 0) {
|
||||
ByteBuffer buffer = buffers[j++];
|
||||
copyBytes(ioBuffer, buffer);
|
||||
ByteBufferUtils.copyBytes(ioBuffer, buffer);
|
||||
}
|
||||
channelBuffer.incrementIndex(bytesRead);
|
||||
return bytesRead;
|
||||
|
@ -299,24 +280,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
// copying.
|
||||
private final int WRITE_LIMIT = 1 << 16;
|
||||
|
||||
protected int flushToChannel(ByteBuffer buffer) throws IOException {
|
||||
int initialPosition = buffer.position();
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
|
||||
copyBytes(buffer, ioBuffer);
|
||||
ioBuffer.flip();
|
||||
int bytesWritten;
|
||||
try {
|
||||
bytesWritten = rawChannel.write(ioBuffer);
|
||||
} catch (IOException e) {
|
||||
closeNow = true;
|
||||
buffer.position(initialPosition);
|
||||
throw e;
|
||||
}
|
||||
buffer.position(initialPosition + bytesWritten);
|
||||
return bytesWritten;
|
||||
}
|
||||
|
||||
protected int flushToChannel(FlushOperation flushOperation) throws IOException {
|
||||
ByteBuffer ioBuffer = getSelector().getIoBuffer();
|
||||
|
||||
|
@ -325,12 +288,8 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
while (continueFlush) {
|
||||
ioBuffer.clear();
|
||||
ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
|
||||
int j = 0;
|
||||
ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT);
|
||||
while (j < buffers.length && ioBuffer.remaining() > 0) {
|
||||
ByteBuffer buffer = buffers[j++];
|
||||
copyBytes(buffer, ioBuffer);
|
||||
}
|
||||
ByteBufferUtils.copyBytes(buffers, ioBuffer);
|
||||
ioBuffer.flip();
|
||||
int bytesFlushed;
|
||||
try {
|
||||
|
@ -345,12 +304,4 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
|
|||
}
|
||||
return totalBytesFlushed;
|
||||
}
|
||||
|
||||
private void copyBytes(ByteBuffer from, ByteBuffer to) {
|
||||
int nBytesToCopy = Math.min(to.remaining(), from.remaining());
|
||||
int initialLimit = from.limit();
|
||||
from.limit(from.position() + nBytesToCopy);
|
||||
to.put(from);
|
||||
from.limit(initialLimit);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.nio.utils;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public final class ByteBufferUtils {
|
||||
|
||||
private ByteBufferUtils() {}
|
||||
|
||||
/**
|
||||
* Copies bytes from the array of byte buffers into the destination buffer. The number of bytes copied is
|
||||
* limited by the bytes available to copy and the space remaining in the destination byte buffer.
|
||||
*
|
||||
* @param source byte buffers to copy from
|
||||
* @param destination byte buffer to copy to
|
||||
*
|
||||
* @return number of bytes copied
|
||||
*/
|
||||
public static long copyBytes(ByteBuffer[] source, ByteBuffer destination) {
|
||||
long bytesCopied = 0;
|
||||
for (int i = 0; i < source.length && destination.hasRemaining(); i++) {
|
||||
ByteBuffer buffer = source[i];
|
||||
bytesCopied += copyBytes(buffer, destination);
|
||||
}
|
||||
return bytesCopied;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copies bytes from source byte buffer into the destination buffer. The number of bytes copied is
|
||||
* limited by the bytes available to copy and the space remaining in the destination byte buffer.
|
||||
*
|
||||
* @param source byte buffer to copy from
|
||||
* @param destination byte buffer to copy to
|
||||
*
|
||||
* @return number of bytes copied
|
||||
*/
|
||||
public static int copyBytes(ByteBuffer source, ByteBuffer destination) {
|
||||
int nBytesToCopy = Math.min(destination.remaining(), source.remaining());
|
||||
int initialLimit = source.limit();
|
||||
source.limit(source.position() + nBytesToCopy);
|
||||
destination.put(source);
|
||||
source.limit(initialLimit);
|
||||
return nBytesToCopy;
|
||||
}
|
||||
}
|
|
@ -19,23 +19,25 @@
|
|||
|
||||
package org.elasticsearch.nio;
|
||||
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
public class InboundChannelBufferTests extends ESTestCase {
|
||||
|
||||
private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES;
|
||||
private final Supplier<Page> defaultPageSupplier = () ->
|
||||
new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
|
||||
});
|
||||
private IntFunction<Page> defaultPageAllocator;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
defaultPageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
|
||||
}
|
||||
|
||||
public void testNewBufferNoPages() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
|
||||
assertEquals(0, channelBuffer.getCapacity());
|
||||
assertEquals(0, channelBuffer.getRemaining());
|
||||
|
@ -43,107 +45,107 @@ public class InboundChannelBufferTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testExpandCapacity() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
assertEquals(0, channelBuffer.getCapacity());
|
||||
assertEquals(0, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE + 1);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
|
||||
|
||||
assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE * 2, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getRemaining());
|
||||
}
|
||||
|
||||
public void testExpandCapacityMultiplePages() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
|
||||
|
||||
int multiple = randomInt(80);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE + ((multiple * PAGE_SIZE) - randomInt(500)));
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + ((multiple * InboundChannelBuffer.PAGE_SIZE) - randomInt(500)));
|
||||
|
||||
assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
|
||||
}
|
||||
|
||||
public void testExpandCapacityRespectsOffset() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
|
||||
|
||||
int offset = randomInt(300);
|
||||
|
||||
channelBuffer.release(offset);
|
||||
|
||||
assertEquals(PAGE_SIZE - offset, channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE - offset, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE + 1);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
|
||||
|
||||
assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
|
||||
assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
|
||||
}
|
||||
|
||||
public void testIncrementIndex() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(0, channelBuffer.getIndex());
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.incrementIndex(10);
|
||||
|
||||
assertEquals(10, channelBuffer.getIndex());
|
||||
assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
|
||||
}
|
||||
|
||||
public void testIncrementIndexWithOffset() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(0, channelBuffer.getIndex());
|
||||
assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.release(10);
|
||||
assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.incrementIndex(10);
|
||||
|
||||
assertEquals(10, channelBuffer.getIndex());
|
||||
assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
|
||||
|
||||
channelBuffer.release(2);
|
||||
assertEquals(8, channelBuffer.getIndex());
|
||||
assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
|
||||
}
|
||||
|
||||
public void testReleaseClosesPages() {
|
||||
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
|
||||
Supplier<Page> supplier = () -> {
|
||||
IntFunction<Page> allocator = (n) -> {
|
||||
AtomicBoolean atomicBoolean = new AtomicBoolean();
|
||||
queue.add(atomicBoolean);
|
||||
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
|
||||
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
|
||||
};
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
|
||||
|
||||
assertEquals(PAGE_SIZE * 4, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 4, channelBuffer.getCapacity());
|
||||
assertEquals(4, queue.size());
|
||||
|
||||
for (AtomicBoolean closedRef : queue) {
|
||||
assertFalse(closedRef.get());
|
||||
}
|
||||
|
||||
channelBuffer.release(2 * PAGE_SIZE);
|
||||
channelBuffer.release(2 * InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
|
||||
assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
|
||||
|
||||
assertTrue(queue.poll().get());
|
||||
assertTrue(queue.poll().get());
|
||||
|
@ -153,13 +155,13 @@ public class InboundChannelBufferTests extends ESTestCase {
|
|||
|
||||
public void testClose() {
|
||||
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
|
||||
Supplier<Page> supplier = () -> {
|
||||
IntFunction<Page> allocator = (n) -> {
|
||||
AtomicBoolean atomicBoolean = new AtomicBoolean();
|
||||
queue.add(atomicBoolean);
|
||||
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
|
||||
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
|
||||
};
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
|
||||
|
||||
assertEquals(4, queue.size());
|
||||
|
||||
|
@ -178,13 +180,13 @@ public class InboundChannelBufferTests extends ESTestCase {
|
|||
|
||||
public void testCloseRetainedPages() {
|
||||
ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
|
||||
Supplier<Page> supplier = () -> {
|
||||
IntFunction<Page> allocator = (n) -> {
|
||||
AtomicBoolean atomicBoolean = new AtomicBoolean();
|
||||
queue.add(atomicBoolean);
|
||||
return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
|
||||
return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
|
||||
};
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
|
||||
channelBuffer.ensureCapacity(PAGE_SIZE * 4);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
|
||||
channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
|
||||
|
||||
assertEquals(4, queue.size());
|
||||
|
||||
|
@ -192,7 +194,7 @@ public class InboundChannelBufferTests extends ESTestCase {
|
|||
assertFalse(closedRef.get());
|
||||
}
|
||||
|
||||
Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
|
||||
Page[] pages = channelBuffer.sliceAndRetainPagesTo(InboundChannelBuffer.PAGE_SIZE * 2);
|
||||
|
||||
pages[1].close();
|
||||
|
||||
|
@ -220,10 +222,10 @@ public class InboundChannelBufferTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testAccessByteBuffers() {
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
|
||||
InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
|
||||
|
||||
int pages = randomInt(50) + 5;
|
||||
channelBuffer.ensureCapacity(pages * PAGE_SIZE);
|
||||
channelBuffer.ensureCapacity(pages * InboundChannelBuffer.PAGE_SIZE);
|
||||
|
||||
long capacity = channelBuffer.getCapacity();
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
|||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.IntFunction;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyInt;
|
||||
|
@ -285,8 +285,8 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
when(channel.getRawChannel()).thenReturn(realChannel);
|
||||
when(channel.isOpen()).thenReturn(true);
|
||||
Runnable closer = mock(Runnable.class);
|
||||
Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer);
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||
IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer);
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator);
|
||||
buffer.ensureCapacity(1);
|
||||
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);
|
||||
context.closeFromSelector();
|
||||
|
@ -294,29 +294,6 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testReadToBufferLimitsToPassedBuffer() throws IOException {
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
|
||||
|
||||
int bytesRead = context.readFromChannel(buffer);
|
||||
assertEquals(bytesRead, 10);
|
||||
assertEquals(0, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testReadToBufferHandlesIOException() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10)));
|
||||
assertTrue(context.closeNow());
|
||||
}
|
||||
|
||||
public void testReadToBufferHandlesEOF() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
|
||||
|
||||
context.readFromChannel(ByteBuffer.allocate(10));
|
||||
assertTrue(context.closeNow());
|
||||
}
|
||||
|
||||
public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException {
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
|
||||
|
||||
|
@ -344,33 +321,6 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
assertEquals(0, channelBuffer.getIndex());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesPartialFlush() throws IOException {
|
||||
int bytesToConsume = 3;
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
context.flushToChannel(buffer);
|
||||
assertEquals(10 - bytesToConsume, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesFullFlush() throws IOException {
|
||||
int bytesToConsume = 10;
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
context.flushToChannel(buffer);
|
||||
assertEquals(0, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBufferHandlesIOException() throws IOException {
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(10);
|
||||
expectThrows(IOException.class, () -> context.flushToChannel(buffer));
|
||||
assertTrue(context.closeNow());
|
||||
assertEquals(10, buffer.remaining());
|
||||
}
|
||||
|
||||
public void testFlushBuffersHandlesZeroFlush() throws IOException {
|
||||
when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0));
|
||||
|
||||
|
@ -456,22 +406,14 @@ public class SocketChannelContextTests extends ESTestCase {
|
|||
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
if (randomBoolean()) {
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
return readFromChannel(channelBuffer);
|
||||
} else {
|
||||
return readFromChannel(ByteBuffer.allocate(10));
|
||||
}
|
||||
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
return readFromChannel(channelBuffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flushChannel() throws IOException {
|
||||
if (randomBoolean()) {
|
||||
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
|
||||
flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
|
||||
} else {
|
||||
flushToChannel(ByteBuffer.allocate(10));
|
||||
}
|
||||
ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
|
||||
flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.logging.log4j.Logger;
|
|||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.network.NetworkService;
|
||||
import org.elasticsearch.common.recycler.Recycler;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.settings.SettingsException;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
|
@ -43,16 +42,15 @@ import org.elasticsearch.nio.InboundChannelBuffer;
|
|||
import org.elasticsearch.nio.NioGroup;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.Page;
|
||||
import org.elasticsearch.nio.ServerChannelContext;
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.rest.RestUtils;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.nio.NioGroupFactory;
|
||||
import org.elasticsearch.transport.nio.PageAllocator;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ServerSocketChannel;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.Arrays;
|
||||
|
@ -80,8 +78,8 @@ import static org.elasticsearch.http.nio.cors.NioCorsHandler.ANY_ORIGIN;
|
|||
public class NioHttpServerTransport extends AbstractHttpServerTransport {
|
||||
private static final Logger logger = LogManager.getLogger(NioHttpServerTransport.class);
|
||||
|
||||
protected final PageCacheRecycler pageCacheRecycler;
|
||||
protected final NioCorsConfig corsConfig;
|
||||
protected final PageAllocator pageAllocator;
|
||||
private final NioGroupFactory nioGroupFactory;
|
||||
|
||||
protected final boolean tcpNoDelay;
|
||||
|
@ -97,7 +95,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
|
|||
PageCacheRecycler pageCacheRecycler, ThreadPool threadPool, NamedXContentRegistry xContentRegistry,
|
||||
Dispatcher dispatcher, NioGroupFactory nioGroupFactory) {
|
||||
super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
|
||||
this.pageCacheRecycler = pageCacheRecycler;
|
||||
this.pageAllocator = new PageAllocator(pageCacheRecycler);
|
||||
this.nioGroupFactory = nioGroupFactory;
|
||||
|
||||
ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
|
||||
|
@ -206,15 +204,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
|
|||
@Override
|
||||
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
|
||||
NioHttpChannel httpChannel = new NioHttpChannel(channel);
|
||||
java.util.function.Supplier<Page> pageSupplier = () -> {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
};
|
||||
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
|
||||
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
|
||||
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
|
||||
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline,
|
||||
new InboundChannelBuffer(pageSupplier));
|
||||
new InboundChannelBuffer(pageAllocator));
|
||||
httpChannel.setContext(context);
|
||||
return httpChannel;
|
||||
}
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.network.NetworkService;
|
||||
import org.elasticsearch.common.recycler.Recycler;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||
|
@ -36,20 +35,17 @@ import org.elasticsearch.nio.InboundChannelBuffer;
|
|||
import org.elasticsearch.nio.NioGroup;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.Page;
|
||||
import org.elasticsearch.nio.ServerChannelContext;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TcpTransport;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ServerSocketChannel;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
|
||||
|
||||
|
@ -57,6 +53,7 @@ public class NioTransport extends TcpTransport {
|
|||
|
||||
private static final Logger logger = LogManager.getLogger(NioTransport.class);
|
||||
|
||||
protected final PageAllocator pageAllocator;
|
||||
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
|
||||
private final NioGroupFactory groupFactory;
|
||||
private volatile NioGroup nioGroup;
|
||||
|
@ -66,6 +63,7 @@ public class NioTransport extends TcpTransport {
|
|||
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
|
||||
CircuitBreakerService circuitBreakerService, NioGroupFactory groupFactory) {
|
||||
super(settings, version, threadPool, pageCacheRecycler, circuitBreakerService, namedWriteableRegistry, networkService);
|
||||
this.pageAllocator = new PageAllocator(pageCacheRecycler);
|
||||
this.groupFactory = groupFactory;
|
||||
}
|
||||
|
||||
|
@ -158,14 +156,10 @@ public class NioTransport extends TcpTransport {
|
|||
@Override
|
||||
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
|
||||
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
|
||||
Supplier<Page> pageSupplier = () -> {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
};
|
||||
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
|
||||
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
|
||||
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler,
|
||||
new InboundChannelBuffer(pageSupplier));
|
||||
new InboundChannelBuffer(pageAllocator));
|
||||
nioChannel.setContext(context);
|
||||
return nioChannel;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.transport.nio;
|
||||
|
||||
import org.elasticsearch.common.recycler.Recycler;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.nio.Page;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
public class PageAllocator implements IntFunction<Page> {
|
||||
|
||||
private static final int RECYCLE_LOWER_THRESHOLD = PageCacheRecycler.BYTE_PAGE_SIZE / 2;
|
||||
|
||||
private final PageCacheRecycler recycler;
|
||||
|
||||
public PageAllocator(PageCacheRecycler recycler) {
|
||||
this.recycler = recycler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Page apply(int length) {
|
||||
if (length >= RECYCLE_LOWER_THRESHOLD && length <= PageCacheRecycler.BYTE_PAGE_SIZE){
|
||||
Recycler.V<byte[]> bytePage = recycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytePage.v(), 0, length), bytePage::close);
|
||||
} else {
|
||||
return new Page(ByteBuffer.allocate(length), () -> {});
|
||||
}
|
||||
}
|
||||
}
|
|
@ -37,8 +37,8 @@ import org.elasticsearch.nio.BytesChannelContext;
|
|||
import org.elasticsearch.nio.BytesWriteHandler;
|
||||
import org.elasticsearch.nio.ChannelFactory;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSelectorGroup;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSelectorGroup;
|
||||
import org.elasticsearch.nio.NioServerSocketChannel;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.Page;
|
||||
|
@ -61,7 +61,7 @@ import java.util.HashSet;
|
|||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
|
||||
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
|
||||
|
@ -192,9 +192,13 @@ public class MockNioTransport extends TcpTransport {
|
|||
@Override
|
||||
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
|
||||
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
|
||||
Supplier<Page> pageSupplier = () -> {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
IntFunction<Page> pageSupplier = (length) -> {
|
||||
if (length > PageCacheRecycler.BYTE_PAGE_SIZE) {
|
||||
return new Page(ByteBuffer.allocate(length), () -> {});
|
||||
} else {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v(), 0, length), bytes::close);
|
||||
}
|
||||
};
|
||||
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
|
||||
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),
|
||||
|
|
|
@ -36,19 +36,22 @@ public final class SSLChannelContext extends SocketChannelContext {
|
|||
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
|
||||
|
||||
private final SSLDriver sslDriver;
|
||||
private final InboundChannelBuffer networkReadBuffer;
|
||||
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
|
||||
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
|
||||
|
||||
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
|
||||
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
|
||||
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
|
||||
ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
|
||||
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
|
||||
applicationBuffer, ALWAYS_ALLOW_CHANNEL);
|
||||
}
|
||||
|
||||
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
|
||||
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
|
||||
ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer,
|
||||
Predicate<NioSocketChannel> allowChannelPredicate) {
|
||||
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
|
||||
this.sslDriver = sslDriver;
|
||||
this.networkReadBuffer = networkReadBuffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -157,12 +160,12 @@ public final class SSLChannelContext extends SocketChannelContext {
|
|||
if (closeNow()) {
|
||||
return bytesRead;
|
||||
}
|
||||
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
|
||||
bytesRead = readFromChannel(networkReadBuffer);
|
||||
if (bytesRead == 0) {
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
sslDriver.read(channelBuffer);
|
||||
sslDriver.read(networkReadBuffer, channelBuffer);
|
||||
|
||||
handleReadBytes();
|
||||
// It is possible that a read call produced non-application bytes to flush
|
||||
|
@ -201,7 +204,7 @@ public final class SSLChannelContext extends SocketChannelContext {
|
|||
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
|
||||
}
|
||||
encryptedFlushes.clear();
|
||||
IOUtils.close(super::closeFromSelector, sslDriver::close);
|
||||
IOUtils.close(super::closeFromSelector, networkReadBuffer::close, sslDriver::close);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.security.transport.nio;
|
|||
import org.elasticsearch.nio.FlushOperation;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.Page;
|
||||
import org.elasticsearch.nio.utils.ByteBufferUtils;
|
||||
import org.elasticsearch.nio.utils.ExceptionsHelper;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
@ -16,6 +17,7 @@ import javax.net.ssl.SSLException;
|
|||
import javax.net.ssl.SSLSession;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
/**
|
||||
* SSLDriver is a class that wraps the {@link SSLEngine} and attempts to simplify the API. The basic usage is
|
||||
|
@ -27,9 +29,9 @@ import java.util.ArrayList;
|
|||
* application to be written to the wire.
|
||||
*
|
||||
* Handling reads from a channel with this class is very simple. When data has been read, call
|
||||
* {@link #read(InboundChannelBuffer)}. If the data is application data, it will be decrypted and placed into
|
||||
* the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
|
||||
* or handshake process.
|
||||
* {@link #read(InboundChannelBuffer, InboundChannelBuffer)}. If the data is application data, it will be
|
||||
* decrypted and placed into the application buffer passed as an argument. Otherwise, it will be consumed
|
||||
* internally and advance the SSL/TLS close or handshake process.
|
||||
*
|
||||
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
|
||||
* called to determine if this driver needs to produce more data to advance the handshake or close process.
|
||||
|
@ -54,21 +56,22 @@ public class SSLDriver implements AutoCloseable {
|
|||
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
|
||||
|
||||
private final SSLEngine engine;
|
||||
// TODO: When the bytes are actually recycled, we need to test that they are released on driver close
|
||||
private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
|
||||
private final IntFunction<Page> pageAllocator;
|
||||
private final SSLOutboundBuffer outboundBuffer;
|
||||
private Page networkReadPage;
|
||||
private final boolean isClientMode;
|
||||
// This should only be accessed by the network thread associated with this channel, so nothing needs to
|
||||
// be volatile.
|
||||
private Mode currentMode = new HandshakeMode();
|
||||
private ByteBuffer networkReadBuffer;
|
||||
private int packetSize;
|
||||
|
||||
public SSLDriver(SSLEngine engine, boolean isClientMode) {
|
||||
public SSLDriver(SSLEngine engine, IntFunction<Page> pageAllocator, boolean isClientMode) {
|
||||
this.engine = engine;
|
||||
this.pageAllocator = pageAllocator;
|
||||
this.outboundBuffer = new SSLOutboundBuffer(pageAllocator);
|
||||
this.isClientMode = isClientMode;
|
||||
SSLSession session = engine.getSession();
|
||||
packetSize = session.getPacketBufferSize();
|
||||
this.networkReadBuffer = ByteBuffer.allocate(packetSize);
|
||||
}
|
||||
|
||||
public void init() throws SSLException {
|
||||
|
@ -106,22 +109,25 @@ public class SSLDriver implements AutoCloseable {
|
|||
return currentMode.isHandshake();
|
||||
}
|
||||
|
||||
public ByteBuffer getNetworkReadBuffer() {
|
||||
return networkReadBuffer;
|
||||
}
|
||||
|
||||
public SSLOutboundBuffer getOutboundBuffer() {
|
||||
return outboundBuffer;
|
||||
}
|
||||
|
||||
public void read(InboundChannelBuffer buffer) throws SSLException {
|
||||
Mode modePriorToRead;
|
||||
do {
|
||||
modePriorToRead = currentMode;
|
||||
currentMode.read(buffer);
|
||||
// If we switched modes we want to read again as there might be unhandled bytes that need to be
|
||||
// handled by the new mode.
|
||||
} while (modePriorToRead != currentMode);
|
||||
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
|
||||
networkReadPage = pageAllocator.apply(packetSize);
|
||||
try {
|
||||
Mode modePriorToRead;
|
||||
do {
|
||||
modePriorToRead = currentMode;
|
||||
currentMode.read(encryptedBuffer, applicationBuffer);
|
||||
// It is possible that we received multiple SSL packets from the network since the last read.
|
||||
// If one of those packets causes us to change modes (such as finished handshaking), we need
|
||||
// to call read in the new mode to handle the remaining packets.
|
||||
} while (modePriorToRead != currentMode);
|
||||
} finally {
|
||||
networkReadPage.close();
|
||||
networkReadPage = null;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean readyForApplicationWrites() {
|
||||
|
@ -171,27 +177,34 @@ public class SSLDriver implements AutoCloseable {
|
|||
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
|
||||
}
|
||||
|
||||
private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException {
|
||||
private SSLEngineResult unwrap(InboundChannelBuffer networkBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
|
||||
while (true) {
|
||||
SSLEngineResult result = engine.unwrap(networkReadBuffer, buffer.sliceBuffersFrom(buffer.getIndex()));
|
||||
buffer.incrementIndex(result.bytesProduced());
|
||||
ensureApplicationBufferSize(applicationBuffer);
|
||||
ByteBuffer networkReadBuffer = networkReadPage.byteBuffer();
|
||||
networkReadBuffer.clear();
|
||||
ByteBufferUtils.copyBytes(networkBuffer.sliceBuffersTo(Math.min(networkBuffer.getIndex(), packetSize)), networkReadBuffer);
|
||||
networkReadBuffer.flip();
|
||||
SSLEngineResult result = engine.unwrap(networkReadBuffer, applicationBuffer.sliceBuffersFrom(applicationBuffer.getIndex()));
|
||||
networkBuffer.release(result.bytesConsumed());
|
||||
applicationBuffer.incrementIndex(result.bytesProduced());
|
||||
switch (result.getStatus()) {
|
||||
case OK:
|
||||
networkReadBuffer.compact();
|
||||
return result;
|
||||
case BUFFER_UNDERFLOW:
|
||||
// There is not enough space in the network buffer for an entire SSL packet. Compact the
|
||||
// current data and expand the buffer if necessary.
|
||||
int currentCapacity = networkReadBuffer.capacity();
|
||||
ensureNetworkReadBufferSize();
|
||||
if (currentCapacity == networkReadBuffer.capacity()) {
|
||||
networkReadBuffer.compact();
|
||||
packetSize = engine.getSession().getPacketBufferSize();
|
||||
if (networkReadPage.byteBuffer().capacity() < packetSize) {
|
||||
networkReadPage.close();
|
||||
networkReadPage = pageAllocator.apply(packetSize);
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
break;
|
||||
case BUFFER_OVERFLOW:
|
||||
// There is not enough space in the application buffer for the decrypted message. Expand
|
||||
// the application buffer to ensure that it has enough space.
|
||||
ensureApplicationBufferSize(buffer);
|
||||
ensureApplicationBufferSize(applicationBuffer);
|
||||
break;
|
||||
case CLOSED:
|
||||
assert engine.isInboundDone() : "We received close_notify so read should be done";
|
||||
|
@ -254,15 +267,6 @@ public class SSLDriver implements AutoCloseable {
|
|||
}
|
||||
}
|
||||
|
||||
private void ensureNetworkReadBufferSize() {
|
||||
packetSize = engine.getSession().getPacketBufferSize();
|
||||
if (networkReadBuffer.capacity() < packetSize) {
|
||||
ByteBuffer newBuffer = ByteBuffer.allocate(packetSize);
|
||||
networkReadBuffer.flip();
|
||||
newBuffer.put(networkReadBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
// There are three potential modes for the driver to be in - HANDSHAKE, APPLICATION, or CLOSE. HANDSHAKE
|
||||
// is the initial mode. During this mode data that is read and written will be related to the TLS
|
||||
// handshake process. Application related data cannot be encrypted until the handshake is complete. From
|
||||
|
@ -282,7 +286,7 @@ public class SSLDriver implements AutoCloseable {
|
|||
|
||||
private interface Mode {
|
||||
|
||||
void read(InboundChannelBuffer buffer) throws SSLException;
|
||||
void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException;
|
||||
|
||||
int write(FlushOperation applicationBytes) throws SSLException;
|
||||
|
||||
|
@ -342,13 +346,11 @@ public class SSLDriver implements AutoCloseable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void read(InboundChannelBuffer buffer) throws SSLException {
|
||||
ensureApplicationBufferSize(buffer);
|
||||
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
|
||||
boolean continueUnwrap = true;
|
||||
while (continueUnwrap && networkReadBuffer.position() > 0) {
|
||||
networkReadBuffer.flip();
|
||||
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
|
||||
try {
|
||||
SSLEngineResult result = unwrap(buffer);
|
||||
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
|
||||
handshakeStatus = result.getHandshakeStatus();
|
||||
handshake();
|
||||
// If we are done handshaking we should exit the handshake read
|
||||
|
@ -430,12 +432,10 @@ public class SSLDriver implements AutoCloseable {
|
|||
private class ApplicationMode implements Mode {
|
||||
|
||||
@Override
|
||||
public void read(InboundChannelBuffer buffer) throws SSLException {
|
||||
ensureApplicationBufferSize(buffer);
|
||||
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
|
||||
boolean continueUnwrap = true;
|
||||
while (continueUnwrap && networkReadBuffer.position() > 0) {
|
||||
networkReadBuffer.flip();
|
||||
SSLEngineResult result = unwrap(buffer);
|
||||
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
|
||||
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
|
||||
boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
|
||||
&& maybeRenegotiation(result.getHandshakeStatus());
|
||||
continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false;
|
||||
|
@ -515,7 +515,7 @@ public class SSLDriver implements AutoCloseable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void read(InboundChannelBuffer buffer) throws SSLException {
|
||||
public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
|
||||
if (needToReceiveClose == false) {
|
||||
// There is an issue where receiving handshake messages after initiating the close process
|
||||
// can place the SSLEngine back into handshaking mode. In order to handle this, if we
|
||||
|
@ -524,11 +524,9 @@ public class SSLDriver implements AutoCloseable {
|
|||
return;
|
||||
}
|
||||
|
||||
ensureApplicationBufferSize(buffer);
|
||||
boolean continueUnwrap = true;
|
||||
while (continueUnwrap && networkReadBuffer.position() > 0) {
|
||||
networkReadBuffer.flip();
|
||||
SSLEngineResult result = unwrap(buffer);
|
||||
while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
|
||||
SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
|
||||
continueUnwrap = result.bytesProduced() > 0 || result.bytesConsumed() > 0;
|
||||
}
|
||||
if (engine.isInboundDone()) {
|
||||
|
|
|
@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.transport.nio;
|
|||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.common.network.NetworkService;
|
||||
import org.elasticsearch.common.recycler.Recycler;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
|
@ -22,7 +21,6 @@ import org.elasticsearch.nio.ChannelFactory;
|
|||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.Page;
|
||||
import org.elasticsearch.nio.ServerChannelContext;
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
|
@ -35,11 +33,9 @@ import org.elasticsearch.xpack.security.transport.filter.IPFilter;
|
|||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ServerSocketChannel;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.core.XPackSettings.HTTP_SSL_ENABLED;
|
||||
|
||||
|
@ -93,13 +89,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
|
|||
@Override
|
||||
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
|
||||
NioHttpChannel httpChannel = new NioHttpChannel(channel);
|
||||
Supplier<Page> pageSupplier = () -> {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
};
|
||||
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
|
||||
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
Consumer<Exception> exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e);
|
||||
|
||||
SocketChannelContext context;
|
||||
|
@ -113,10 +105,12 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
|
|||
} else {
|
||||
sslEngine = sslService.createSSLEngine(sslConfiguration, null, -1);
|
||||
}
|
||||
SSLDriver sslDriver = new SSLDriver(sslEngine, false);
|
||||
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, buffer, nioIpFilter);
|
||||
SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
|
||||
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer,
|
||||
applicationBuffer, nioIpFilter);
|
||||
} else {
|
||||
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, buffer, nioIpFilter);
|
||||
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter);
|
||||
}
|
||||
httpChannel.setContext(context);
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.network.NetworkService;
|
||||
import org.elasticsearch.common.recycler.Recycler;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||
|
@ -21,7 +20,6 @@ import org.elasticsearch.nio.ChannelFactory;
|
|||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSelector;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.Page;
|
||||
import org.elasticsearch.nio.ServerChannelContext;
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
|
@ -45,14 +43,12 @@ import javax.net.ssl.SSLEngine;
|
|||
import javax.net.ssl.SSLParameters;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ServerSocketChannel;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.core.security.SecurityField.setting;
|
||||
|
||||
|
@ -156,20 +152,18 @@ public class SecurityNioTransport extends NioTransport {
|
|||
@Override
|
||||
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
|
||||
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
|
||||
Supplier<Page> pageSupplier = () -> {
|
||||
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
|
||||
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
};
|
||||
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
|
||||
|
||||
SocketChannelContext context;
|
||||
if (sslEnabled) {
|
||||
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient);
|
||||
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter);
|
||||
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
|
||||
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer,
|
||||
applicationBuffer, ipFilter);
|
||||
} else {
|
||||
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter);
|
||||
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter);
|
||||
}
|
||||
nioChannel.setContext(context);
|
||||
|
||||
|
|
|
@ -52,7 +52,6 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
private BiConsumer<Void, Exception> listener;
|
||||
private Consumer exceptionHandler;
|
||||
private SSLDriver sslDriver;
|
||||
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
|
||||
private int messageLength;
|
||||
|
||||
@Before
|
||||
|
@ -76,7 +75,6 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
|
||||
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||
when(selector.getTaskScheduler()).thenReturn(nioTimer);
|
||||
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
|
||||
when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer);
|
||||
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
|
||||
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
|
||||
|
@ -88,8 +86,12 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
public void testSuccessfulRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
|
||||
|
||||
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
|
||||
|
||||
|
@ -103,8 +105,12 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
public void testMultipleReadsConsumed() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength * 2);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
|
||||
|
||||
when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
|
||||
|
||||
|
@ -118,8 +124,12 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
public void testPartialRead() throws IOException {
|
||||
byte[] bytes = createMessage(messageLength);
|
||||
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
|
||||
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
|
||||
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.put(bytes);
|
||||
return bytes.length;
|
||||
});
|
||||
doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
|
||||
|
||||
|
||||
when(readConsumer.apply(channelBuffer)).thenReturn(0);
|
||||
|
@ -424,12 +434,12 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
|
||||
private Answer getReadAnswerForBytes(byte[] bytes) {
|
||||
return invocationOnMock -> {
|
||||
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
|
||||
buffer.ensureCapacity(buffer.getIndex() + bytes.length);
|
||||
ByteBuffer[] buffers = buffer.sliceBuffersFrom(buffer.getIndex());
|
||||
InboundChannelBuffer appBuffer = (InboundChannelBuffer) invocationOnMock.getArguments()[1];
|
||||
appBuffer.ensureCapacity(appBuffer.getIndex() + bytes.length);
|
||||
ByteBuffer[] buffers = appBuffer.sliceBuffersFrom(appBuffer.getIndex());
|
||||
assert buffers[0].remaining() > bytes.length;
|
||||
buffers[0].put(bytes);
|
||||
buffer.incrementIndex(bytes.length);
|
||||
appBuffer.incrementIndex(bytes.length);
|
||||
return bytes.length;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -26,14 +26,16 @@ import java.security.SecureRandom;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
public class SSLDriverTests extends ESTestCase {
|
||||
|
||||
private final Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {});
|
||||
private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
|
||||
private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
|
||||
private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
|
||||
private final IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
|
||||
|
||||
private final InboundChannelBuffer networkReadBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
private final InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
|
||||
private final AtomicInteger openPages = new AtomicInteger(0);
|
||||
|
||||
public void testPingPongAndClose() throws Exception {
|
||||
SSLContext sslContext = getSSLContext();
|
||||
|
@ -44,19 +46,36 @@ public class SSLDriverTests extends ESTestCase {
|
|||
handshake(clientDriver, serverDriver);
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
|
||||
sendAppData(clientDriver, serverDriver, buffers);
|
||||
serverDriver.read(serverBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(clientDriver, buffers);
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
|
||||
sendAppData(serverDriver, clientDriver, buffers2);
|
||||
clientDriver.read(clientBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(serverDriver, buffers2);
|
||||
clientDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
assertFalse(clientDriver.needsNonApplicationWrite());
|
||||
normalClose(clientDriver, serverDriver);
|
||||
}
|
||||
|
||||
public void testDataStoredInOutboundBufferIsClosed() throws Exception {
|
||||
SSLContext sslContext = getSSLContext();
|
||||
|
||||
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
|
||||
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
|
||||
|
||||
handshake(clientDriver, serverDriver);
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
|
||||
serverDriver.write(new FlushOperation(buffers, (v, e) -> {}));
|
||||
|
||||
expectThrows(SSLException.class, serverDriver::close);
|
||||
assertEquals(0, openPages.get());
|
||||
}
|
||||
|
||||
public void testRenegotiate() throws Exception {
|
||||
SSLContext sslContext = getSSLContext();
|
||||
|
||||
|
@ -73,9 +92,10 @@ public class SSLDriverTests extends ESTestCase {
|
|||
handshake(clientDriver, serverDriver);
|
||||
|
||||
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
|
||||
sendAppData(clientDriver, serverDriver, buffers);
|
||||
serverDriver.read(serverBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(clientDriver, buffers);
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
clientDriver.renegotiate();
|
||||
assertTrue(clientDriver.isHandshaking());
|
||||
|
@ -83,17 +103,20 @@ public class SSLDriverTests extends ESTestCase {
|
|||
|
||||
// This tests that the client driver can still receive data based on the prior handshake
|
||||
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
|
||||
sendAppData(serverDriver, clientDriver, buffers2);
|
||||
clientDriver.read(clientBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(serverDriver, buffers2);
|
||||
clientDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
handshake(clientDriver, serverDriver, true);
|
||||
sendAppData(clientDriver, serverDriver, buffers);
|
||||
serverDriver.read(serverBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(serverDriver, clientDriver, buffers2);
|
||||
clientDriver.read(clientBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(clientDriver, buffers);
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
sendAppData(serverDriver, buffers2);
|
||||
clientDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
normalClose(clientDriver, serverDriver);
|
||||
}
|
||||
|
@ -108,18 +131,22 @@ public class SSLDriverTests extends ESTestCase {
|
|||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(1 << 15);
|
||||
for (int i = 0; i < (1 << 15); ++i) {
|
||||
buffer.put((byte) i);
|
||||
buffer.put((byte) (i % 127));
|
||||
}
|
||||
buffer.flip();
|
||||
ByteBuffer[] buffers = {buffer};
|
||||
sendAppData(clientDriver, serverDriver, buffers);
|
||||
serverDriver.read(serverBuffer);
|
||||
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[0].limit());
|
||||
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[1].limit());
|
||||
sendAppData(clientDriver, buffers);
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
ByteBuffer[] buffers1 = applicationBuffer.sliceBuffersFrom(0);
|
||||
assertEquals((byte) (16383 % 127), buffers1[0].get(16383));
|
||||
assertEquals((byte) (32767 % 127), buffers1[1].get(16383));
|
||||
applicationBuffer.release(1 << 15);
|
||||
|
||||
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
|
||||
sendAppData(serverDriver, clientDriver, buffers2);
|
||||
clientDriver.read(clientBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
|
||||
sendAppData(serverDriver, buffers2);
|
||||
clientDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
|
||||
applicationBuffer.release(4);
|
||||
|
||||
assertFalse(clientDriver.needsNonApplicationWrite());
|
||||
normalClose(clientDriver, serverDriver);
|
||||
|
@ -193,16 +220,16 @@ public class SSLDriverTests extends ESTestCase {
|
|||
serverDriver.initiateClose();
|
||||
assertTrue(serverDriver.needsNonApplicationWrite());
|
||||
assertFalse(serverDriver.isClosed());
|
||||
sendNonApplicationWrites(serverDriver, clientDriver);
|
||||
sendNonApplicationWrites(serverDriver);
|
||||
// We are immediately fully closed due to SSLEngine inconsistency
|
||||
assertTrue(serverDriver.isClosed());
|
||||
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
|
||||
clientDriver.read(clientBuffer);
|
||||
sendNonApplicationWrites(clientDriver, serverDriver);
|
||||
clientDriver.read(clientBuffer);
|
||||
sendNonApplicationWrites(clientDriver, serverDriver);
|
||||
serverDriver.read(serverBuffer);
|
||||
|
||||
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
|
||||
assertEquals("Received close_notify during handshake", sslException.getMessage());
|
||||
sendNonApplicationWrites(clientDriver);
|
||||
assertTrue(clientDriver.isClosed());
|
||||
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
}
|
||||
|
||||
public void testCloseDuringHandshakePreJDK11() throws Exception {
|
||||
|
@ -226,26 +253,28 @@ public class SSLDriverTests extends ESTestCase {
|
|||
serverDriver.initiateClose();
|
||||
assertTrue(serverDriver.needsNonApplicationWrite());
|
||||
assertFalse(serverDriver.isClosed());
|
||||
sendNonApplicationWrites(serverDriver, clientDriver);
|
||||
sendNonApplicationWrites(serverDriver);
|
||||
// We are immediately fully closed due to SSLEngine inconsistency
|
||||
assertTrue(serverDriver.isClosed());
|
||||
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer));
|
||||
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
|
||||
|
||||
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
|
||||
assertEquals("Received close_notify during handshake", sslException.getMessage());
|
||||
assertTrue(clientDriver.needsNonApplicationWrite());
|
||||
sendNonApplicationWrites(clientDriver, serverDriver);
|
||||
serverDriver.read(serverBuffer);
|
||||
sendNonApplicationWrites(clientDriver);
|
||||
assertTrue(clientDriver.isClosed());
|
||||
|
||||
serverDriver.read(networkReadBuffer, applicationBuffer);
|
||||
}
|
||||
|
||||
private void failedCloseAlert(SSLDriver sendDriver, SSLDriver receiveDriver, List<String> messages) throws SSLException {
|
||||
assertTrue(sendDriver.needsNonApplicationWrite());
|
||||
assertFalse(sendDriver.isClosed());
|
||||
|
||||
sendNonApplicationWrites(sendDriver, receiveDriver);
|
||||
sendNonApplicationWrites(sendDriver);
|
||||
assertTrue(sendDriver.isClosed());
|
||||
sendDriver.close();
|
||||
|
||||
SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(genericBuffer));
|
||||
SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(networkReadBuffer, applicationBuffer));
|
||||
assertTrue("Expected one of the following exception messages: " + messages + ". Found: " + sslException.getMessage(),
|
||||
messages.stream().anyMatch(m -> sslException.getMessage().equals(m)));
|
||||
if (receiveDriver.needsNonApplicationWrite() == false) {
|
||||
|
@ -274,29 +303,30 @@ public class SSLDriverTests extends ESTestCase {
|
|||
sendDriver.initiateClose();
|
||||
assertFalse(sendDriver.readyForApplicationWrites());
|
||||
assertTrue(sendDriver.needsNonApplicationWrite());
|
||||
sendNonApplicationWrites(sendDriver, receiveDriver);
|
||||
sendNonApplicationWrites(sendDriver);
|
||||
assertFalse(sendDriver.isClosed());
|
||||
|
||||
receiveDriver.read(genericBuffer);
|
||||
receiveDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertFalse(receiveDriver.isClosed());
|
||||
|
||||
assertFalse(receiveDriver.readyForApplicationWrites());
|
||||
assertTrue(receiveDriver.needsNonApplicationWrite());
|
||||
sendNonApplicationWrites(receiveDriver, sendDriver);
|
||||
sendNonApplicationWrites(receiveDriver);
|
||||
assertTrue(receiveDriver.isClosed());
|
||||
|
||||
sendDriver.read(genericBuffer);
|
||||
sendDriver.read(networkReadBuffer, applicationBuffer);
|
||||
assertTrue(sendDriver.isClosed());
|
||||
|
||||
sendDriver.close();
|
||||
receiveDriver.close();
|
||||
assertEquals(0, openPages.get());
|
||||
}
|
||||
|
||||
private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
|
||||
private void sendNonApplicationWrites(SSLDriver sendDriver) throws SSLException {
|
||||
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
|
||||
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
|
||||
if (outboundBuffer.hasEncryptedBytesToFlush()) {
|
||||
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
|
||||
sendData(outboundBuffer.buildNetworkFlushOperation());
|
||||
} else {
|
||||
sendDriver.nonApplicationWrite();
|
||||
}
|
||||
|
@ -342,8 +372,8 @@ public class SSLDriverTests extends ESTestCase {
|
|||
|
||||
while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
|
||||
if (outboundBuffer.hasEncryptedBytesToFlush()) {
|
||||
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
|
||||
receiveDriver.read(genericBuffer);
|
||||
sendData(outboundBuffer.buildNetworkFlushOperation());
|
||||
receiveDriver.read(networkReadBuffer, applicationBuffer);
|
||||
} else {
|
||||
sendDriver.nonApplicationWrite();
|
||||
}
|
||||
|
@ -353,37 +383,46 @@ public class SSLDriverTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
|
||||
private void sendAppData(SSLDriver sendDriver, ByteBuffer[] message) throws IOException {
|
||||
assertFalse(sendDriver.needsNonApplicationWrite());
|
||||
|
||||
int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
|
||||
SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
|
||||
FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {});
|
||||
|
||||
int bytesEncrypted = 0;
|
||||
while (bytesToEncrypt > bytesEncrypted) {
|
||||
bytesEncrypted += sendDriver.write(flushOperation);
|
||||
sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
|
||||
while (flushOperation.isFullyFlushed() == false) {
|
||||
sendDriver.write(flushOperation);
|
||||
}
|
||||
sendData(sendDriver.getOutboundBuffer().buildNetworkFlushOperation());
|
||||
}
|
||||
|
||||
private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) {
|
||||
ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
|
||||
private void sendData(FlushOperation flushOperation) {
|
||||
ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite();
|
||||
int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
|
||||
assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer";
|
||||
int bytesToCopy = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
|
||||
networkReadBuffer.ensureCapacity(bytesToCopy + networkReadBuffer.getIndex());
|
||||
ByteBuffer[] byteBuffers = networkReadBuffer.sliceBuffersFrom(0);
|
||||
assert writeBuffers.length > 0 : "No write buffers";
|
||||
|
||||
for (ByteBuffer writeBuffer : writeBuffers) {
|
||||
int written = writeBuffer.remaining();
|
||||
int r = 0;
|
||||
while (flushOperation.isFullyFlushed() == false) {
|
||||
ByteBuffer readBuffer = byteBuffers[r];
|
||||
ByteBuffer writeBuffer = flushOperation.getBuffersToWrite()[0];
|
||||
int toWrite = Math.min(writeBuffer.remaining(), readBuffer.remaining());
|
||||
writeBuffer.limit(writeBuffer.position() + toWrite);
|
||||
readBuffer.put(writeBuffer);
|
||||
flushOperation.incrementIndex(written);
|
||||
flushOperation.incrementIndex(toWrite);
|
||||
if (readBuffer.remaining() == 0) {
|
||||
r++;
|
||||
}
|
||||
}
|
||||
networkReadBuffer.incrementIndex(bytesToCopy);
|
||||
|
||||
assertTrue(flushOperation.isFullyFlushed());
|
||||
flushOperation.getListener().accept(null, null);
|
||||
}
|
||||
|
||||
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {
|
||||
return new SSLDriver(engine, isClient);
|
||||
return new SSLDriver(engine, (n) -> {
|
||||
openPages.incrementAndGet();
|
||||
return new Page(ByteBuffer.allocate(n), openPages::decrementAndGet);
|
||||
}, isClient);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue