diff --git a/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 28afa151d80..a4e88ec70f2 100644 --- a/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -5,10 +5,10 @@ */ package org.elasticsearch.xpack.security.transport.nio; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.BytesWriteOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.utils.ExceptionsHelper; @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; +import java.util.function.Consumer; /** * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake @@ -35,16 +36,17 @@ public final class SSLChannelContext extends SocketChannelContext { private final InboundChannelBuffer buffer; private final AtomicBoolean isClosing = new AtomicBoolean(false); - SSLChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler, SSLDriver sslDriver, + SSLChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) { - super(channel, exceptionHandler); + super(channel, selector, exceptionHandler); this.sslDriver = sslDriver; this.readConsumer = readConsumer; this.buffer = buffer; } @Override - public void channelRegistered() throws IOException { + public void register() throws IOException { + super.register(); sslDriver.init(); } @@ -55,8 +57,8 @@ public final class SSLChannelContext extends SocketChannelContext { return; } - BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener); - SocketSelector selector = channel.getSelector(); + BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener); + SocketSelector selector = getSelector(); if (selector.isOnCurrentThread() == false) { // If this message is being sent from another thread, we queue the write to be handled by the // network thread @@ -69,7 +71,7 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public void queueWriteOperation(WriteOperation writeOperation) { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); if (writeOperation instanceof CloseNotifyOperation) { sslDriver.initiateClose(); } else { @@ -100,7 +102,7 @@ public final class SSLChannelContext extends SocketChannelContext { // sent (as we only get to this point if the write buffer has been fully flushed). if (currentOperation.isFullyFlushed()) { queued.removeFirst(); - channel.getSelector().executeListener(currentOperation.getListener(), null); + getSelector().executeListener(currentOperation.getListener(), null); currentOperation = queued.peekFirst(); } else { try { @@ -115,7 +117,7 @@ public final class SSLChannelContext extends SocketChannelContext { flushToChannel(sslDriver.getNetworkWriteBuffer()); } catch (IOException e) { queued.removeFirst(); - channel.getSelector().executeFailedListener(currentOperation.getListener(), e); + getSelector().executeFailedListener(currentOperation.getListener(), e); throw e; } } @@ -135,7 +137,7 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public boolean hasQueuedWriteOps() { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); if (sslDriver.readyForApplicationWrites()) { return sslDriver.hasFlushPending() || queued.isEmpty() == false; } else { @@ -173,8 +175,8 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { - WriteOperation writeOperation = new CloseNotifyOperation(channel); - SocketSelector selector = channel.getSelector(); + WriteOperation writeOperation = new CloseNotifyOperation(this); + SocketSelector selector = getSelector(); if (selector.isOnCurrentThread() == false) { selector.queueWrite(writeOperation); return; @@ -185,20 +187,20 @@ public final class SSLChannelContext extends SocketChannelContext { @Override public void closeFromSelector() throws IOException { - channel.getSelector().assertOnSelectorThread(); + getSelector().assertOnSelectorThread(); if (channel.isOpen()) { // Set to true in order to reject new writes before queuing with selector isClosing.set(true); ArrayList closingExceptions = new ArrayList<>(2); try { - channel.closeFromSelector(); + super.closeFromSelector(); } catch (IOException e) { closingExceptions.add(e); } try { buffer.close(); for (BytesWriteOperation op : queued) { - channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); } queued.clear(); sslDriver.close(); @@ -212,10 +214,10 @@ public final class SSLChannelContext extends SocketChannelContext { private static class CloseNotifyOperation implements WriteOperation { private static final BiConsumer LISTENER = (v, t) -> {}; - private final NioSocketChannel channel; + private final SocketChannelContext channelContext; - private CloseNotifyOperation(NioSocketChannel channel) { - this.channel = channel; + private CloseNotifyOperation(SocketChannelContext channelContext) { + this.channelContext = channelContext; } @Override @@ -224,8 +226,8 @@ public final class SSLChannelContext extends SocketChannelContext { } @Override - public NioSocketChannel getChannel() { - return channel; + public SocketChannelContext getChannel() { + return channelContext; } } } diff --git a/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index c3a139ba084..7773404762e 100644 --- a/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -38,7 +38,7 @@ import java.nio.channels.SocketChannel; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.elasticsearch.xpack.core.security.SecurityField.setting; @@ -122,7 +122,7 @@ public class SecurityNioTransport extends NioTransport { SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE); SSLEngine sslEngine = sslService.createSSLEngine(profileConfiguration.getOrDefault(profileName, defaultConfig), null, -1); SSLDriver sslDriver = new SSLDriver(sslEngine, isClient); - TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector); + TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel); Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); @@ -131,16 +131,18 @@ public class SecurityNioTransport extends NioTransport { SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - BiConsumer exceptionHandler = SecurityNioTransport.this::exceptionCaught; - SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer); + Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); + SSLChannelContext context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, nioReadConsumer, + buffer); nioChannel.setContext(context); return nioChannel; } @Override public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); - ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {}); + TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, SecurityNioTransport.this::acceptChannel, + (e) -> {}); nioChannel.setContext(context); return nioChannel; } diff --git a/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index a6d702df89a..884b348721f 100644 --- a/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -6,10 +6,10 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.BytesWriteOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.test.ESTestCase; @@ -20,8 +20,11 @@ import org.mockito.stubbing.Answer; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.mockito.Matchers.any; @@ -37,10 +40,12 @@ public class SSLChannelContextTests extends ESTestCase { private SocketChannelContext.ReadConsumer readConsumer; private NioSocketChannel channel; + private SocketChannel rawChannel; private SSLChannelContext context; private InboundChannelBuffer channelBuffer; private SocketSelector selector; private BiConsumer listener; + private Consumer exceptionHandler; private SSLDriver sslDriver; private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14); private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14); @@ -55,11 +60,13 @@ public class SSLChannelContextTests extends ESTestCase { selector = mock(SocketSelector.class); listener = mock(BiConsumer.class); channel = mock(NioSocketChannel.class); + rawChannel = mock(SocketChannel.class); sslDriver = mock(SSLDriver.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer); + when(channel.getRawChannel()).thenReturn(rawChannel); + exceptionHandler = mock(Consumer.class); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); - when(channel.getSelector()).thenReturn(selector); when(selector.isOnCurrentThread()).thenReturn(true); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer); @@ -68,7 +75,7 @@ public class SSLChannelContextTests extends ESTestCase { public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(same(readBuffer))).thenReturn(bytes.length); + when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length); doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0); @@ -83,7 +90,7 @@ public class SSLChannelContextTests extends ESTestCase { public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); - when(channel.read(same(readBuffer))).thenReturn(bytes.length); + when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length); doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0); @@ -98,7 +105,7 @@ public class SSLChannelContextTests extends ESTestCase { public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); - when(channel.read(same(readBuffer))).thenReturn(bytes.length); + when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length); doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); @@ -120,14 +127,14 @@ public class SSLChannelContextTests extends ESTestCase { public void testReadThrowsIOException() throws IOException { IOException ioException = new IOException(); - when(channel.read(any(ByteBuffer.class))).thenThrow(ioException); + when(rawChannel.read(any(ByteBuffer.class))).thenThrow(ioException); IOException ex = expectThrows(IOException.class, () -> context.read()); assertSame(ioException, ex); } public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer.class))).thenThrow(new IOException()); + when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.read()); @@ -135,7 +142,7 @@ public class SSLChannelContextTests extends ESTestCase { } public void testReadLessThanZeroMeansReadyForClose() throws IOException { - when(channel.read(any(ByteBuffer.class))).thenReturn(-1); + when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1); assertEquals(0, context.read()); @@ -144,39 +151,53 @@ public class SSLChannelContextTests extends ESTestCase { @SuppressWarnings("unchecked") public void testCloseClosesChannelBuffer() throws IOException { - AtomicInteger closeCount = new AtomicInteger(0); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), - closeCount::incrementAndGet); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - buffer.ensureCapacity(1); - SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer); - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); - assertEquals(1, closeCount.get()); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + + AtomicInteger closeCount = new AtomicInteger(0); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), + closeCount::incrementAndGet); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + buffer.ensureCapacity(1); + SSLChannelContext context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, buffer); + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); + assertEquals(1, closeCount.get()); + } } + @SuppressWarnings("unchecked") public void testWriteOpsClearedOnClose() throws IOException { - assertFalse(context.hasQueuedWriteOps()); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + assertFalse(context.hasQueuedWriteOps()); - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - when(sslDriver.readyForApplicationWrites()).thenReturn(true); - assertTrue(context.hasQueuedWriteOps()); + when(sslDriver.readyForApplicationWrites()).thenReturn(true); + assertTrue(context.hasQueuedWriteOps()); - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); - verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); + verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.hasQueuedWriteOps()); + } } + @SuppressWarnings("unchecked") public void testSSLDriverClosedOnClose() throws IOException { - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); - verify(sslDriver).close(); + verify(sslDriver).close(); + } } public void testWriteFailsIfClosing() { @@ -200,7 +221,7 @@ public class SSLChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -214,7 +235,7 @@ public class SSLChannelContextTests extends ESTestCase { BytesWriteOperation writeOp = writeOpCaptor.getValue(); assertSame(listener, writeOp.getListener()); - assertSame(channel, writeOp.getChannel()); + assertSame(context, writeOp.getChannel()); assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); } @@ -225,7 +246,7 @@ public class SSLChannelContextTests extends ESTestCase { assertFalse(context.hasQueuedWriteOps()); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); assertTrue(context.hasQueuedWriteOps()); } @@ -236,7 +257,7 @@ public class SSLChannelContextTests extends ESTestCase { when(sslDriver.needsNonApplicationWrite()).thenReturn(false); ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener)); + context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); assertFalse(context.hasQueuedWriteOps()); } @@ -283,7 +304,7 @@ public class SSLChannelContextTests extends ESTestCase { context.flushChannel(); verify(sslDriver, times(2)).nonApplicationWrite(); - verify(channel, times(2)).write(sslDriver.getNetworkWriteBuffer()); + verify(rawChannel, times(2)).write(sslDriver.getNetworkWriteBuffer()); } public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { @@ -294,7 +315,7 @@ public class SSLChannelContextTests extends ESTestCase { context.flushChannel(); verify(sslDriver, times(1)).nonApplicationWrite(); - verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer()); + verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer()); } public void testQueuedWriteIsFlushedInFlushCall() throws Exception { @@ -311,7 +332,7 @@ public class SSLChannelContextTests extends ESTestCase { context.flushChannel(); verify(writeOperation).incrementIndex(10); - verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer()); + verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer()); verify(selector).executeListener(listener, null); assertFalse(context.hasQueuedWriteOps()); } @@ -330,7 +351,7 @@ public class SSLChannelContextTests extends ESTestCase { context.flushChannel(); verify(writeOperation).incrementIndex(5); - verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer()); + verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer()); verify(selector, times(0)).executeListener(listener, null); assertTrue(context.hasQueuedWriteOps()); } @@ -358,7 +379,7 @@ public class SSLChannelContextTests extends ESTestCase { context.flushChannel(); verify(writeOperation1, times(2)).incrementIndex(5); - verify(channel, times(3)).write(sslDriver.getNetworkWriteBuffer()); + verify(rawChannel, times(3)).write(sslDriver.getNetworkWriteBuffer()); verify(selector).executeListener(listener, null); verify(selector, times(0)).executeListener(listener2, null); assertTrue(context.hasQueuedWriteOps()); @@ -375,7 +396,7 @@ public class SSLChannelContextTests extends ESTestCase { when(sslDriver.hasFlushPending()).thenReturn(false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.applicationWrite(buffers)).thenReturn(5); - when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception); + when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception); when(writeOperation.isFullyFlushed()).thenReturn(false); expectThrows(IOException.class, () -> context.flushChannel()); @@ -388,7 +409,7 @@ public class SSLChannelContextTests extends ESTestCase { when(sslDriver.hasFlushPending()).thenReturn(true); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException()); + when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); expectThrows(IOException.class, () -> context.flushChannel()); @@ -422,9 +443,17 @@ public class SSLChannelContextTests extends ESTestCase { verify(sslDriver).initiateClose(); } + @SuppressWarnings("unchecked") public void testRegisterInitiatesDriver() throws IOException { - context.channelRegistered(); - verify(sslDriver).init(); + try (Selector realSelector = Selector.open(); + SocketChannel realSocket = SocketChannel.open()) { + realSocket.configureBlocking(false); + when(selector.rawSelector()).thenReturn(realSelector); + when(channel.getRawChannel()).thenReturn(realSocket); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + context.register(); + verify(sslDriver).init(); + } } private Answer getAnswerForBytes(byte[] bytes) {