Support changes in nio channel contexts (elastic/x-pack-elasticsearch#3609)

This is related to elastic/elasticsearch#elastic/x-pack-elasticsearch#28275. It modifies x-pack to
support the changes in channel contexts. Additionally, it simplifies
the SSLChannelContext by relying on some common work between it and
BytesChannelContext.

Original commit: elastic/x-pack-elasticsearch@8a8fcce050
This commit is contained in:
Tim Brooks 2018-01-18 13:06:42 -07:00 committed by GitHub
parent e775e84a7e
commit 685b75da3a
3 changed files with 64 additions and 61 deletions

View File

@ -5,16 +5,18 @@
*/ */
package org.elasticsearch.xpack.security.transport.nio; package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation; import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
@ -22,22 +24,20 @@ import java.util.function.BiConsumer;
/** /**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted * with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
* before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will * before being passed to the {@link ReadConsumer}. Outbound data will
* be encrypted before being flushed to the channel. * be encrypted before being flushed to the channel.
*/ */
public final class SSLChannelContext implements ChannelContext { public final class SSLChannelContext extends SocketChannelContext {
private final NioSocketChannel channel;
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>(); private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
private final SSLDriver sslDriver; private final SSLDriver sslDriver;
private final ReadConsumer readConsumer; private final ReadConsumer readConsumer;
private final InboundChannelBuffer buffer; private final InboundChannelBuffer buffer;
private final AtomicBoolean isClosing = new AtomicBoolean(false); private final AtomicBoolean isClosing = new AtomicBoolean(false);
private boolean peerClosed = false;
private boolean ioException = false;
SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) { SSLChannelContext(NioSocketChannel channel, BiConsumer<NioSocketChannel, Exception> exceptionHandler, SSLDriver sslDriver,
this.channel = channel; ReadConsumer readConsumer, InboundChannelBuffer buffer) {
super(channel, exceptionHandler);
this.sslDriver = sslDriver; this.sslDriver = sslDriver;
this.readConsumer = readConsumer; this.readConsumer = readConsumer;
this.buffer = buffer; this.buffer = buffer;
@ -64,7 +64,6 @@ public final class SSLChannelContext implements ChannelContext {
return; return;
} }
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation); selector.queueWriteInChannelBuffer(writeOperation);
} }
@ -80,14 +79,14 @@ public final class SSLChannelContext implements ChannelContext {
@Override @Override
public void flushChannel() throws IOException { public void flushChannel() throws IOException {
if (ioException) { if (hasIOException()) {
return; return;
} }
// If there is currently data in the outbound write buffer, flush the buffer. // If there is currently data in the outbound write buffer, flush the buffer.
if (sslDriver.hasFlushPending()) { if (sslDriver.hasFlushPending()) {
internalFlush();
// If the data is not completely flushed, exit. We cannot produce new write data until the // If the data is not completely flushed, exit. We cannot produce new write data until the
// existing data has been fully flushed. // existing data has been fully flushed.
flushToChannel(sslDriver.getNetworkWriteBuffer());
if (sslDriver.hasFlushPending()) { if (sslDriver.hasFlushPending()) {
return; return;
} }
@ -113,7 +112,7 @@ public final class SSLChannelContext implements ChannelContext {
} }
currentOperation.incrementIndex(bytesEncrypted); currentOperation.incrementIndex(bytesEncrypted);
// Flush the write buffer to the channel // Flush the write buffer to the channel
internalFlush(); flushToChannel(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) { } catch (IOException e) {
queued.removeFirst(); queued.removeFirst();
channel.getSelector().executeFailedListener(currentOperation.getListener(), e); channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
@ -128,21 +127,12 @@ public final class SSLChannelContext implements ChannelContext {
sslDriver.nonApplicationWrite(); sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer. // If non-application writes were produced, flush the outbound write buffer.
if (sslDriver.hasFlushPending()) { if (sslDriver.hasFlushPending()) {
internalFlush(); flushToChannel(sslDriver.getNetworkWriteBuffer());
} }
} }
} }
} }
private int internalFlush() throws IOException {
try {
return channel.write(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
ioException = true;
throw e;
}
}
@Override @Override
public boolean hasQueuedWriteOps() { public boolean hasQueuedWriteOps() {
channel.getSelector().assertOnSelectorThread(); channel.getSelector().assertOnSelectorThread();
@ -156,18 +146,12 @@ public final class SSLChannelContext implements ChannelContext {
@Override @Override
public int read() throws IOException { public int read() throws IOException {
int bytesRead = 0; int bytesRead = 0;
if (ioException) { if (hasIOException()) {
return bytesRead; return bytesRead;
} }
try { bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
bytesRead = channel.read(sslDriver.getNetworkReadBuffer()); if (bytesRead == 0) {
} catch (IOException e) { return bytesRead;
ioException = true;
throw e;
}
if (bytesRead < 0) {
peerClosed = true;
return 0;
} }
sslDriver.read(buffer); sslDriver.read(buffer);
@ -183,7 +167,7 @@ public final class SSLChannelContext implements ChannelContext {
@Override @Override
public boolean selectorShouldClose() { public boolean selectorShouldClose() {
return peerClosed || ioException || sslDriver.isClosed(); return isPeerClosed() || hasIOException() || sslDriver.isClosed();
} }
@Override @Override
@ -202,14 +186,27 @@ public final class SSLChannelContext implements ChannelContext {
@Override @Override
public void closeFromSelector() throws IOException { public void closeFromSelector() throws IOException {
channel.getSelector().assertOnSelectorThread(); channel.getSelector().assertOnSelectorThread();
if (channel.isOpen()) {
// Set to true in order to reject new writes before queuing with selector // Set to true in order to reject new writes before queuing with selector
isClosing.set(true); isClosing.set(true);
ArrayList<IOException> closingExceptions = new ArrayList<>(2);
try {
channel.closeFromSelector();
} catch (IOException e) {
closingExceptions.add(e);
}
try {
buffer.close(); buffer.close();
for (BytesWriteOperation op : queued) { for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
} }
queued.clear(); queued.clear();
sslDriver.close(); sslDriver.close();
} catch (IOException e) {
closingExceptions.add(e);
}
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
}
} }
private static class CloseNotifyOperation implements WriteOperation { private static class CloseNotifyOperation implements WriteOperation {

View File

@ -13,10 +13,11 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
@ -36,6 +37,7 @@ import java.nio.channels.SocketChannel;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.elasticsearch.xpack.security.SecurityField.setting; import static org.elasticsearch.xpack.security.SecurityField.setting;
@ -125,19 +127,21 @@ public class SecurityNioTransport extends NioTransport {
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
}; };
ChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer); BiConsumer<NioSocketChannel, Exception> exceptionHandler = SecurityNioTransport.this::exceptionCaught;
nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught); SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer);
nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }
@Override @Override
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector); TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel); ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {});
return nioServerChannel; nioChannel.setContext(context);
return nioChannel;
} }
} }
} }

View File

@ -6,9 +6,8 @@
package org.elasticsearch.xpack.security.transport.nio; package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation; import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.SocketSelector;
@ -21,6 +20,7 @@ import org.mockito.stubbing.Answer;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -35,7 +35,7 @@ import static org.mockito.Mockito.when;
public class SSLChannelContextTests extends ESTestCase { public class SSLChannelContextTests extends ESTestCase {
private ChannelContext.ReadConsumer readConsumer; private SocketChannelContext.ReadConsumer readConsumer;
private NioSocketChannel channel; private NioSocketChannel channel;
private SSLChannelContext context; private SSLChannelContext context;
private InboundChannelBuffer channelBuffer; private InboundChannelBuffer channelBuffer;
@ -49,18 +49,15 @@ public class SSLChannelContextTests extends ESTestCase {
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void init() { public void init() {
readConsumer = mock(ChannelContext.ReadConsumer.class); readConsumer = mock(SocketChannelContext.ReadConsumer.class);
messageLength = randomInt(96) + 20; messageLength = randomInt(96) + 20;
selector = mock(SocketSelector.class); selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class); listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class); channel = mock(NioSocketChannel.class);
sslDriver = mock(SSLDriver.class); sslDriver = mock(SSLDriver.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> channelBuffer = InboundChannelBuffer.allocatingInstance();
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> { context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer);
});
channelBuffer = new InboundChannelBuffer(pageSupplier);
context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer);
when(channel.getSelector()).thenReturn(selector); when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true); when(selector.isOnCurrentThread()).thenReturn(true);
@ -145,14 +142,17 @@ public class SSLChannelContextTests extends ESTestCase {
assertTrue(context.selectorShouldClose()); assertTrue(context.selectorShouldClose());
} }
@SuppressWarnings("unchecked")
public void testCloseClosesChannelBuffer() throws IOException { public void testCloseClosesChannelBuffer() throws IOException {
Runnable closer = mock(Runnable.class); AtomicInteger closeCount = new AtomicInteger(0);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
closeCount::incrementAndGet);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1); buffer.ensureCapacity(1);
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer); SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector(); context.closeFromSelector();
verify(closer).run(); assertEquals(1, closeCount.get());
} }
public void testWriteOpsClearedOnClose() throws IOException { public void testWriteOpsClearedOnClose() throws IOException {
@ -164,6 +164,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps()); assertTrue(context.hasQueuedWriteOps());
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector(); context.closeFromSelector();
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
@ -172,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase {
} }
public void testSSLDriverClosedOnClose() throws IOException { public void testSSLDriverClosedOnClose() throws IOException {
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector(); context.closeFromSelector();
verify(sslDriver).close(); verify(sslDriver).close();