From 685b75da3a8fe1ccac7ab44927056bc136a3af72 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 18 Jan 2018 13:06:42 -0700 Subject: [PATCH] Support changes in nio channel contexts (elastic/x-pack-elasticsearch#3609) This is related to elastic/elasticsearch#elastic/x-pack-elasticsearch#28275. It modifies x-pack to support the changes in channel contexts. Additionally, it simplifies the SSLChannelContext by relying on some common work between it and BytesChannelContext. Original commit: elastic/x-pack-elasticsearch@8a8fcce050940e2b1f86dbbd20e7f150f488ebe6 --- .../transport/nio/SSLChannelContext.java | 77 +++++++++---------- .../transport/nio/SecurityNioTransport.java | 20 +++-- .../transport/nio/SSLChannelContextTests.java | 28 +++---- 3 files changed, 64 insertions(+), 61 deletions(-) diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 3ae21ef8f49..28afa151d80 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -5,16 +5,18 @@ */ package org.elasticsearch.xpack.security.transport.nio; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.BytesWriteOperation; -import org.elasticsearch.nio.ChannelContext; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.WriteOperation; +import org.elasticsearch.nio.utils.ExceptionsHelper; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; import java.util.LinkedList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; @@ -22,22 +24,20 @@ import java.util.function.BiConsumer; /** * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake * with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted - * before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will + * before being passed to the {@link ReadConsumer}. Outbound data will * be encrypted before being flushed to the channel. */ -public final class SSLChannelContext implements ChannelContext { +public final class SSLChannelContext extends SocketChannelContext { - private final NioSocketChannel channel; private final LinkedList queued = new LinkedList<>(); private final SSLDriver sslDriver; private final ReadConsumer readConsumer; private final InboundChannelBuffer buffer; private final AtomicBoolean isClosing = new AtomicBoolean(false); - private boolean peerClosed = false; - private boolean ioException = false; - SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) { - this.channel = channel; + SSLChannelContext(NioSocketChannel channel, BiConsumer exceptionHandler, SSLDriver sslDriver, + ReadConsumer readConsumer, InboundChannelBuffer buffer) { + super(channel, exceptionHandler); this.sslDriver = sslDriver; this.readConsumer = readConsumer; this.buffer = buffer; @@ -64,7 +64,6 @@ public final class SSLChannelContext implements ChannelContext { return; } - // TODO: Eval if we will allow writes from sendMessage selector.queueWriteInChannelBuffer(writeOperation); } @@ -80,14 +79,14 @@ public final class SSLChannelContext implements ChannelContext { @Override public void flushChannel() throws IOException { - if (ioException) { + if (hasIOException()) { return; } // If there is currently data in the outbound write buffer, flush the buffer. if (sslDriver.hasFlushPending()) { - internalFlush(); // If the data is not completely flushed, exit. We cannot produce new write data until the // existing data has been fully flushed. + flushToChannel(sslDriver.getNetworkWriteBuffer()); if (sslDriver.hasFlushPending()) { return; } @@ -113,7 +112,7 @@ public final class SSLChannelContext implements ChannelContext { } currentOperation.incrementIndex(bytesEncrypted); // Flush the write buffer to the channel - internalFlush(); + flushToChannel(sslDriver.getNetworkWriteBuffer()); } catch (IOException e) { queued.removeFirst(); channel.getSelector().executeFailedListener(currentOperation.getListener(), e); @@ -128,21 +127,12 @@ public final class SSLChannelContext implements ChannelContext { sslDriver.nonApplicationWrite(); // If non-application writes were produced, flush the outbound write buffer. if (sslDriver.hasFlushPending()) { - internalFlush(); + flushToChannel(sslDriver.getNetworkWriteBuffer()); } } } } - private int internalFlush() throws IOException { - try { - return channel.write(sslDriver.getNetworkWriteBuffer()); - } catch (IOException e) { - ioException = true; - throw e; - } - } - @Override public boolean hasQueuedWriteOps() { channel.getSelector().assertOnSelectorThread(); @@ -156,18 +146,12 @@ public final class SSLChannelContext implements ChannelContext { @Override public int read() throws IOException { int bytesRead = 0; - if (ioException) { + if (hasIOException()) { return bytesRead; } - try { - bytesRead = channel.read(sslDriver.getNetworkReadBuffer()); - } catch (IOException e) { - ioException = true; - throw e; - } - if (bytesRead < 0) { - peerClosed = true; - return 0; + bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer()); + if (bytesRead == 0) { + return bytesRead; } sslDriver.read(buffer); @@ -183,7 +167,7 @@ public final class SSLChannelContext implements ChannelContext { @Override public boolean selectorShouldClose() { - return peerClosed || ioException || sslDriver.isClosed(); + return isPeerClosed() || hasIOException() || sslDriver.isClosed(); } @Override @@ -202,14 +186,27 @@ public final class SSLChannelContext implements ChannelContext { @Override public void closeFromSelector() throws IOException { channel.getSelector().assertOnSelectorThread(); - // Set to true in order to reject new writes before queuing with selector - isClosing.set(true); - buffer.close(); - for (BytesWriteOperation op : queued) { - channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + if (channel.isOpen()) { + // Set to true in order to reject new writes before queuing with selector + isClosing.set(true); + ArrayList closingExceptions = new ArrayList<>(2); + try { + channel.closeFromSelector(); + } catch (IOException e) { + closingExceptions.add(e); + } + try { + buffer.close(); + for (BytesWriteOperation op : queued) { + channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); + } + queued.clear(); + sslDriver.close(); + } catch (IOException e) { + closingExceptions.add(e); + } + ExceptionsHelper.rethrowAndSuppress(closingExceptions); } - queued.clear(); - sslDriver.close(); } private static class CloseNotifyOperation implements WriteOperation { diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index a03b1171f17..dfa167770c4 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -13,10 +13,11 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.AcceptingSelector; -import org.elasticsearch.nio.ChannelContext; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; @@ -36,6 +37,7 @@ import java.nio.channels.SocketChannel; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Supplier; import static org.elasticsearch.xpack.security.SecurityField.setting; @@ -125,19 +127,21 @@ public class SecurityNioTransport extends NioTransport { return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - ChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> + SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer); - nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught); + BiConsumer exceptionHandler = SecurityNioTransport.this::exceptionCaught; + SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer); + nioChannel.setContext(context); return nioChannel; } @Override public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { - TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); - nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel); - return nioServerChannel; + TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); + ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {}); + nioChannel.setContext(context); + return nioChannel; } } -} \ No newline at end of file +} diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 7e5a89e747c..a6d702df89a 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -6,9 +6,8 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.BytesWriteOperation; -import org.elasticsearch.nio.ChannelContext; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketSelector; @@ -21,6 +20,7 @@ import org.mockito.stubbing.Answer; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import java.util.function.Supplier; @@ -35,7 +35,7 @@ import static org.mockito.Mockito.when; public class SSLChannelContextTests extends ESTestCase { - private ChannelContext.ReadConsumer readConsumer; + private SocketChannelContext.ReadConsumer readConsumer; private NioSocketChannel channel; private SSLChannelContext context; private InboundChannelBuffer channelBuffer; @@ -49,18 +49,15 @@ public class SSLChannelContextTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void init() { - readConsumer = mock(ChannelContext.ReadConsumer.class); + readConsumer = mock(SocketChannelContext.ReadConsumer.class); messageLength = randomInt(96) + 20; selector = mock(SocketSelector.class); listener = mock(BiConsumer.class); channel = mock(NioSocketChannel.class); sslDriver = mock(SSLDriver.class); - Supplier pageSupplier = () -> - new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> { - }); - channelBuffer = new InboundChannelBuffer(pageSupplier); - context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer); + channelBuffer = InboundChannelBuffer.allocatingInstance(); + context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer); when(channel.getSelector()).thenReturn(selector); when(selector.isOnCurrentThread()).thenReturn(true); @@ -145,14 +142,17 @@ public class SSLChannelContextTests extends ESTestCase { assertTrue(context.selectorShouldClose()); } + @SuppressWarnings("unchecked") public void testCloseClosesChannelBuffer() throws IOException { - Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + AtomicInteger closeCount = new AtomicInteger(0); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), + closeCount::incrementAndGet); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); buffer.ensureCapacity(1); - BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer); + SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer); + when(channel.isOpen()).thenReturn(true); context.closeFromSelector(); - verify(closer).run(); + assertEquals(1, closeCount.get()); } public void testWriteOpsClearedOnClose() throws IOException { @@ -164,6 +164,7 @@ public class SSLChannelContextTests extends ESTestCase { when(sslDriver.readyForApplicationWrites()).thenReturn(true); assertTrue(context.hasQueuedWriteOps()); + when(channel.isOpen()).thenReturn(true); context.closeFromSelector(); verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); @@ -172,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase { } public void testSSLDriverClosedOnClose() throws IOException { + when(channel.isOpen()).thenReturn(true); context.closeFromSelector(); verify(sslDriver).close();