Support changes in nio channel contexts (elastic/x-pack-elasticsearch#3609)

This is related to elastic/elasticsearch#elastic/x-pack-elasticsearch#28275. It modifies x-pack to
support the changes in channel contexts. Additionally, it simplifies
the SSLChannelContext by relying on some common work between it and
BytesChannelContext.

Original commit: elastic/x-pack-elasticsearch@8a8fcce050
This commit is contained in:
Tim Brooks 2018-01-18 13:06:42 -07:00 committed by GitHub
parent e775e84a7e
commit 685b75da3a
3 changed files with 64 additions and 61 deletions

View File

@ -5,16 +5,18 @@
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
@ -22,22 +24,20 @@ import java.util.function.BiConsumer;
/**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
* before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will
* before being passed to the {@link ReadConsumer}. Outbound data will
* be encrypted before being flushed to the channel.
*/
public final class SSLChannelContext implements ChannelContext {
public final class SSLChannelContext extends SocketChannelContext {
private final NioSocketChannel channel;
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
private final SSLDriver sslDriver;
private final ReadConsumer readConsumer;
private final InboundChannelBuffer buffer;
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private boolean peerClosed = false;
private boolean ioException = false;
SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) {
this.channel = channel;
SSLChannelContext(NioSocketChannel channel, BiConsumer<NioSocketChannel, Exception> exceptionHandler, SSLDriver sslDriver,
ReadConsumer readConsumer, InboundChannelBuffer buffer) {
super(channel, exceptionHandler);
this.sslDriver = sslDriver;
this.readConsumer = readConsumer;
this.buffer = buffer;
@ -64,7 +64,6 @@ public final class SSLChannelContext implements ChannelContext {
return;
}
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation);
}
@ -80,14 +79,14 @@ public final class SSLChannelContext implements ChannelContext {
@Override
public void flushChannel() throws IOException {
if (ioException) {
if (hasIOException()) {
return;
}
// If there is currently data in the outbound write buffer, flush the buffer.
if (sslDriver.hasFlushPending()) {
internalFlush();
// If the data is not completely flushed, exit. We cannot produce new write data until the
// existing data has been fully flushed.
flushToChannel(sslDriver.getNetworkWriteBuffer());
if (sslDriver.hasFlushPending()) {
return;
}
@ -113,7 +112,7 @@ public final class SSLChannelContext implements ChannelContext {
}
currentOperation.incrementIndex(bytesEncrypted);
// Flush the write buffer to the channel
internalFlush();
flushToChannel(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
queued.removeFirst();
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
@ -128,21 +127,12 @@ public final class SSLChannelContext implements ChannelContext {
sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer.
if (sslDriver.hasFlushPending()) {
internalFlush();
flushToChannel(sslDriver.getNetworkWriteBuffer());
}
}
}
}
private int internalFlush() throws IOException {
try {
return channel.write(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
ioException = true;
throw e;
}
}
@Override
public boolean hasQueuedWriteOps() {
channel.getSelector().assertOnSelectorThread();
@ -156,18 +146,12 @@ public final class SSLChannelContext implements ChannelContext {
@Override
public int read() throws IOException {
int bytesRead = 0;
if (ioException) {
if (hasIOException()) {
return bytesRead;
}
try {
bytesRead = channel.read(sslDriver.getNetworkReadBuffer());
} catch (IOException e) {
ioException = true;
throw e;
}
if (bytesRead < 0) {
peerClosed = true;
return 0;
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
if (bytesRead == 0) {
return bytesRead;
}
sslDriver.read(buffer);
@ -183,7 +167,7 @@ public final class SSLChannelContext implements ChannelContext {
@Override
public boolean selectorShouldClose() {
return peerClosed || ioException || sslDriver.isClosed();
return isPeerClosed() || hasIOException() || sslDriver.isClosed();
}
@Override
@ -202,14 +186,27 @@ public final class SSLChannelContext implements ChannelContext {
@Override
public void closeFromSelector() throws IOException {
channel.getSelector().assertOnSelectorThread();
// Set to true in order to reject new writes before queuing with selector
isClosing.set(true);
buffer.close();
for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
if (channel.isOpen()) {
// Set to true in order to reject new writes before queuing with selector
isClosing.set(true);
ArrayList<IOException> closingExceptions = new ArrayList<>(2);
try {
channel.closeFromSelector();
} catch (IOException e) {
closingExceptions.add(e);
}
try {
buffer.close();
for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
}
queued.clear();
sslDriver.close();
} catch (IOException e) {
closingExceptions.add(e);
}
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
}
queued.clear();
sslDriver.close();
}
private static class CloseNotifyOperation implements WriteOperation {

View File

@ -13,10 +13,11 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
@ -36,6 +37,7 @@ import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.security.SecurityField.setting;
@ -125,19 +127,21 @@ public class SecurityNioTransport extends NioTransport {
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer);
nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught);
BiConsumer<NioSocketChannel, Exception> exceptionHandler = SecurityNioTransport.this::exceptionCaught;
SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer);
nioChannel.setContext(context);
return nioChannel;
}
@Override
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel);
return nioServerChannel;
TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {});
nioChannel.setContext(context);
return nioChannel;
}
}
}

View File

@ -6,9 +6,8 @@
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
@ -21,6 +20,7 @@ import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
@ -35,7 +35,7 @@ import static org.mockito.Mockito.when;
public class SSLChannelContextTests extends ESTestCase {
private ChannelContext.ReadConsumer readConsumer;
private SocketChannelContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private SSLChannelContext context;
private InboundChannelBuffer channelBuffer;
@ -49,18 +49,15 @@ public class SSLChannelContextTests extends ESTestCase {
@Before
@SuppressWarnings("unchecked")
public void init() {
readConsumer = mock(ChannelContext.ReadConsumer.class);
readConsumer = mock(SocketChannelContext.ReadConsumer.class);
messageLength = randomInt(96) + 20;
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
sslDriver = mock(SSLDriver.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {
});
channelBuffer = new InboundChannelBuffer(pageSupplier);
context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer);
channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
@ -145,14 +142,17 @@ public class SSLChannelContextTests extends ESTestCase {
assertTrue(context.selectorShouldClose());
}
@SuppressWarnings("unchecked")
public void testCloseClosesChannelBuffer() throws IOException {
Runnable closer = mock(Runnable.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
AtomicInteger closeCount = new AtomicInteger(0);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
closeCount::incrementAndGet);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
verify(closer).run();
assertEquals(1, closeCount.get());
}
public void testWriteOpsClearedOnClose() throws IOException {
@ -164,6 +164,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps());
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
@ -172,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase {
}
public void testSSLDriverClosedOnClose() throws IOException {
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
verify(sslDriver).close();