Selectors operate on channel contexts (elastic/x-pack-elasticsearch#3803)

This is related to elastic/elasticsearch#28468. It is a compatibility
commit to ensure that x-pack is compatible with those changes.

Original commit: elastic/x-pack-elasticsearch@3ddf719adc
This commit is contained in:
Tim Brooks 2018-02-22 09:45:07 -07:00 committed by GitHub
parent 531d44f446
commit 33ae455e6c
3 changed files with 102 additions and 69 deletions

View File

@ -5,10 +5,10 @@
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.nio.utils.ExceptionsHelper;
@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
/**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
@ -35,16 +36,17 @@ public final class SSLChannelContext extends SocketChannelContext {
private final InboundChannelBuffer buffer;
private final AtomicBoolean isClosing = new AtomicBoolean(false);
SSLChannelContext(NioSocketChannel channel, BiConsumer<NioSocketChannel, Exception> exceptionHandler, SSLDriver sslDriver,
SSLChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadConsumer readConsumer, InboundChannelBuffer buffer) {
super(channel, exceptionHandler);
super(channel, selector, exceptionHandler);
this.sslDriver = sslDriver;
this.readConsumer = readConsumer;
this.buffer = buffer;
}
@Override
public void channelRegistered() throws IOException {
public void register() throws IOException {
super.register();
sslDriver.init();
}
@ -55,8 +57,8 @@ public final class SSLChannelContext extends SocketChannelContext {
return;
}
BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
SocketSelector selector = channel.getSelector();
BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener);
SocketSelector selector = getSelector();
if (selector.isOnCurrentThread() == false) {
// If this message is being sent from another thread, we queue the write to be handled by the
// network thread
@ -69,7 +71,7 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public void queueWriteOperation(WriteOperation writeOperation) {
channel.getSelector().assertOnSelectorThread();
getSelector().assertOnSelectorThread();
if (writeOperation instanceof CloseNotifyOperation) {
sslDriver.initiateClose();
} else {
@ -100,7 +102,7 @@ public final class SSLChannelContext extends SocketChannelContext {
// sent (as we only get to this point if the write buffer has been fully flushed).
if (currentOperation.isFullyFlushed()) {
queued.removeFirst();
channel.getSelector().executeListener(currentOperation.getListener(), null);
getSelector().executeListener(currentOperation.getListener(), null);
currentOperation = queued.peekFirst();
} else {
try {
@ -115,7 +117,7 @@ public final class SSLChannelContext extends SocketChannelContext {
flushToChannel(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
queued.removeFirst();
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
getSelector().executeFailedListener(currentOperation.getListener(), e);
throw e;
}
}
@ -135,7 +137,7 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public boolean hasQueuedWriteOps() {
channel.getSelector().assertOnSelectorThread();
getSelector().assertOnSelectorThread();
if (sslDriver.readyForApplicationWrites()) {
return sslDriver.hasFlushPending() || queued.isEmpty() == false;
} else {
@ -173,8 +175,8 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public void closeChannel() {
if (isClosing.compareAndSet(false, true)) {
WriteOperation writeOperation = new CloseNotifyOperation(channel);
SocketSelector selector = channel.getSelector();
WriteOperation writeOperation = new CloseNotifyOperation(this);
SocketSelector selector = getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
return;
@ -185,20 +187,20 @@ public final class SSLChannelContext extends SocketChannelContext {
@Override
public void closeFromSelector() throws IOException {
channel.getSelector().assertOnSelectorThread();
getSelector().assertOnSelectorThread();
if (channel.isOpen()) {
// Set to true in order to reject new writes before queuing with selector
isClosing.set(true);
ArrayList<IOException> closingExceptions = new ArrayList<>(2);
try {
channel.closeFromSelector();
super.closeFromSelector();
} catch (IOException e) {
closingExceptions.add(e);
}
try {
buffer.close();
for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
}
queued.clear();
sslDriver.close();
@ -212,10 +214,10 @@ public final class SSLChannelContext extends SocketChannelContext {
private static class CloseNotifyOperation implements WriteOperation {
private static final BiConsumer<Void, Throwable> LISTENER = (v, t) -> {};
private final NioSocketChannel channel;
private final SocketChannelContext channelContext;
private CloseNotifyOperation(NioSocketChannel channel) {
this.channel = channel;
private CloseNotifyOperation(SocketChannelContext channelContext) {
this.channelContext = channelContext;
}
@Override
@ -224,8 +226,8 @@ public final class SSLChannelContext extends SocketChannelContext {
}
@Override
public NioSocketChannel getChannel() {
return channel;
public SocketChannelContext getChannel() {
return channelContext;
}
}
}

View File

@ -38,7 +38,7 @@ import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.security.SecurityField.setting;
@ -122,7 +122,7 @@ public class SecurityNioTransport extends NioTransport {
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLEngine sslEngine = sslService.createSSLEngine(profileConfiguration.getOrDefault(profileName, defaultConfig), null, -1);
SSLDriver sslDriver = new SSLDriver(sslEngine, isClient);
TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector);
TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
@ -131,16 +131,18 @@ public class SecurityNioTransport extends NioTransport {
SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
BiConsumer<NioSocketChannel, Exception> exceptionHandler = SecurityNioTransport.this::exceptionCaught;
SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer);
Consumer<Exception> exceptionHandler = (e) -> exceptionCaught(nioChannel, e);
SSLChannelContext context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, nioReadConsumer,
buffer);
nioChannel.setContext(context);
return nioChannel;
}
@Override
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {});
TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel);
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, SecurityNioTransport.this::acceptChannel,
(e) -> {});
nioChannel.setContext(context);
return nioChannel;
}

View File

@ -6,10 +6,10 @@
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase;
@ -20,8 +20,11 @@ import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
@ -37,10 +40,12 @@ public class SSLChannelContextTests extends ESTestCase {
private SocketChannelContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private SocketChannel rawChannel;
private SSLChannelContext context;
private InboundChannelBuffer channelBuffer;
private SocketSelector selector;
private BiConsumer<Void, Throwable> listener;
private Consumer exceptionHandler;
private SSLDriver sslDriver;
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14);
@ -55,11 +60,13 @@ public class SSLChannelContextTests extends ESTestCase {
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
rawChannel = mock(SocketChannel.class);
sslDriver = mock(SSLDriver.class);
channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer);
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer);
@ -68,7 +75,7 @@ public class SSLChannelContextTests extends ESTestCase {
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
@ -83,7 +90,7 @@ public class SSLChannelContextTests extends ESTestCase {
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
@ -98,7 +105,7 @@ public class SSLChannelContextTests extends ESTestCase {
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
@ -120,14 +127,14 @@ public class SSLChannelContextTests extends ESTestCase {
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any(ByteBuffer.class))).thenThrow(ioException);
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(ioException);
IOException ex = expectThrows(IOException.class, () -> context.read());
assertSame(ioException, ex);
}
public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer.class))).thenThrow(new IOException());
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.read());
@ -135,7 +142,7 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testReadLessThanZeroMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer.class))).thenReturn(-1);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
assertEquals(0, context.read());
@ -144,39 +151,53 @@ public class SSLChannelContextTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testCloseClosesChannelBuffer() throws IOException {
AtomicInteger closeCount = new AtomicInteger(0);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
closeCount::incrementAndGet);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
assertEquals(1, closeCount.get());
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
AtomicInteger closeCount = new AtomicInteger(0);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
closeCount::incrementAndGet);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
SSLChannelContext context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, buffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
assertEquals(1, closeCount.get());
}
}
@SuppressWarnings("unchecked")
public void testWriteOpsClearedOnClose() throws IOException {
assertFalse(context.hasQueuedWriteOps());
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer);
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener));
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps());
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps());
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
assertFalse(context.hasQueuedWriteOps());
assertFalse(context.hasQueuedWriteOps());
}
}
@SuppressWarnings("unchecked")
public void testSSLDriverClosedOnClose() throws IOException {
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
verify(sslDriver).close();
verify(sslDriver).close();
}
}
public void testWriteFailsIfClosing() {
@ -200,7 +221,7 @@ public class SSLChannelContextTests extends ESTestCase {
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertSame(context, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
@ -214,7 +235,7 @@ public class SSLChannelContextTests extends ESTestCase {
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertSame(context, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
@ -225,7 +246,7 @@ public class SSLChannelContextTests extends ESTestCase {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener));
assertTrue(context.hasQueuedWriteOps());
}
@ -236,7 +257,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener));
assertFalse(context.hasQueuedWriteOps());
}
@ -283,7 +304,7 @@ public class SSLChannelContextTests extends ESTestCase {
context.flushChannel();
verify(sslDriver, times(2)).nonApplicationWrite();
verify(channel, times(2)).write(sslDriver.getNetworkWriteBuffer());
verify(rawChannel, times(2)).write(sslDriver.getNetworkWriteBuffer());
}
public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception {
@ -294,7 +315,7 @@ public class SSLChannelContextTests extends ESTestCase {
context.flushChannel();
verify(sslDriver, times(1)).nonApplicationWrite();
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
@ -311,7 +332,7 @@ public class SSLChannelContextTests extends ESTestCase {
context.flushChannel();
verify(writeOperation).incrementIndex(10);
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(selector).executeListener(listener, null);
assertFalse(context.hasQueuedWriteOps());
}
@ -330,7 +351,7 @@ public class SSLChannelContextTests extends ESTestCase {
context.flushChannel();
verify(writeOperation).incrementIndex(5);
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(selector, times(0)).executeListener(listener, null);
assertTrue(context.hasQueuedWriteOps());
}
@ -358,7 +379,7 @@ public class SSLChannelContextTests extends ESTestCase {
context.flushChannel();
verify(writeOperation1, times(2)).incrementIndex(5);
verify(channel, times(3)).write(sslDriver.getNetworkWriteBuffer());
verify(rawChannel, times(3)).write(sslDriver.getNetworkWriteBuffer());
verify(selector).executeListener(listener, null);
verify(selector, times(0)).executeListener(listener2, null);
assertTrue(context.hasQueuedWriteOps());
@ -375,7 +396,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(sslDriver.hasFlushPending()).thenReturn(false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception);
when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception);
when(writeOperation.isFullyFlushed()).thenReturn(false);
expectThrows(IOException.class, () -> context.flushChannel());
@ -388,7 +409,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException());
when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.flushChannel());
@ -422,9 +443,17 @@ public class SSLChannelContextTests extends ESTestCase {
verify(sslDriver).initiateClose();
}
@SuppressWarnings("unchecked")
public void testRegisterInitiatesDriver() throws IOException {
context.channelRegistered();
verify(sslDriver).init();
try (Selector realSelector = Selector.open();
SocketChannel realSocket = SocketChannel.open()) {
realSocket.configureBlocking(false);
when(selector.rawSelector()).thenReturn(realSelector);
when(channel.getRawChannel()).thenReturn(realSocket);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer);
context.register();
verify(sslDriver).init();
}
}
private Answer getAnswerForBytes(byte[] bytes) {