Support changes in nio channel contexts (elastic/x-pack-elasticsearch#3609)
This is related to elastic/elasticsearch#elastic/x-pack-elasticsearch#28275. It modifies x-pack to support the changes in channel contexts. Additionally, it simplifies the SSLChannelContext by relying on some common work between it and BytesChannelContext. Original commit: elastic/x-pack-elasticsearch@8a8fcce050
This commit is contained in:
parent
e775e84a7e
commit
685b75da3a
|
@ -5,16 +5,18 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.security.transport.nio;
|
package org.elasticsearch.xpack.security.transport.nio;
|
||||||
|
|
||||||
|
import org.elasticsearch.nio.SocketChannelContext;
|
||||||
import org.elasticsearch.nio.BytesWriteOperation;
|
import org.elasticsearch.nio.BytesWriteOperation;
|
||||||
import org.elasticsearch.nio.ChannelContext;
|
|
||||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||||
import org.elasticsearch.nio.NioSocketChannel;
|
import org.elasticsearch.nio.NioSocketChannel;
|
||||||
import org.elasticsearch.nio.SocketSelector;
|
import org.elasticsearch.nio.SocketSelector;
|
||||||
import org.elasticsearch.nio.WriteOperation;
|
import org.elasticsearch.nio.WriteOperation;
|
||||||
|
import org.elasticsearch.nio.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.ClosedChannelException;
|
import java.nio.channels.ClosedChannelException;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.function.BiConsumer;
|
import java.util.function.BiConsumer;
|
||||||
|
@ -22,22 +24,20 @@ import java.util.function.BiConsumer;
|
||||||
/**
|
/**
|
||||||
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
|
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
|
||||||
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
|
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
|
||||||
* before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will
|
* before being passed to the {@link ReadConsumer}. Outbound data will
|
||||||
* be encrypted before being flushed to the channel.
|
* be encrypted before being flushed to the channel.
|
||||||
*/
|
*/
|
||||||
public final class SSLChannelContext implements ChannelContext {
|
public final class SSLChannelContext extends SocketChannelContext {
|
||||||
|
|
||||||
private final NioSocketChannel channel;
|
|
||||||
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
|
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
|
||||||
private final SSLDriver sslDriver;
|
private final SSLDriver sslDriver;
|
||||||
private final ReadConsumer readConsumer;
|
private final ReadConsumer readConsumer;
|
||||||
private final InboundChannelBuffer buffer;
|
private final InboundChannelBuffer buffer;
|
||||||
private final AtomicBoolean isClosing = new AtomicBoolean(false);
|
private final AtomicBoolean isClosing = new AtomicBoolean(false);
|
||||||
private boolean peerClosed = false;
|
|
||||||
private boolean ioException = false;
|
|
||||||
|
|
||||||
SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) {
|
SSLChannelContext(NioSocketChannel channel, BiConsumer<NioSocketChannel, Exception> exceptionHandler, SSLDriver sslDriver,
|
||||||
this.channel = channel;
|
ReadConsumer readConsumer, InboundChannelBuffer buffer) {
|
||||||
|
super(channel, exceptionHandler);
|
||||||
this.sslDriver = sslDriver;
|
this.sslDriver = sslDriver;
|
||||||
this.readConsumer = readConsumer;
|
this.readConsumer = readConsumer;
|
||||||
this.buffer = buffer;
|
this.buffer = buffer;
|
||||||
|
@ -64,7 +64,6 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Eval if we will allow writes from sendMessage
|
|
||||||
selector.queueWriteInChannelBuffer(writeOperation);
|
selector.queueWriteInChannelBuffer(writeOperation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,14 +79,14 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void flushChannel() throws IOException {
|
public void flushChannel() throws IOException {
|
||||||
if (ioException) {
|
if (hasIOException()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// If there is currently data in the outbound write buffer, flush the buffer.
|
// If there is currently data in the outbound write buffer, flush the buffer.
|
||||||
if (sslDriver.hasFlushPending()) {
|
if (sslDriver.hasFlushPending()) {
|
||||||
internalFlush();
|
|
||||||
// If the data is not completely flushed, exit. We cannot produce new write data until the
|
// If the data is not completely flushed, exit. We cannot produce new write data until the
|
||||||
// existing data has been fully flushed.
|
// existing data has been fully flushed.
|
||||||
|
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||||
if (sslDriver.hasFlushPending()) {
|
if (sslDriver.hasFlushPending()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -113,7 +112,7 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
}
|
}
|
||||||
currentOperation.incrementIndex(bytesEncrypted);
|
currentOperation.incrementIndex(bytesEncrypted);
|
||||||
// Flush the write buffer to the channel
|
// Flush the write buffer to the channel
|
||||||
internalFlush();
|
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
queued.removeFirst();
|
queued.removeFirst();
|
||||||
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
|
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
|
||||||
|
@ -128,21 +127,12 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
sslDriver.nonApplicationWrite();
|
sslDriver.nonApplicationWrite();
|
||||||
// If non-application writes were produced, flush the outbound write buffer.
|
// If non-application writes were produced, flush the outbound write buffer.
|
||||||
if (sslDriver.hasFlushPending()) {
|
if (sslDriver.hasFlushPending()) {
|
||||||
internalFlush();
|
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int internalFlush() throws IOException {
|
|
||||||
try {
|
|
||||||
return channel.write(sslDriver.getNetworkWriteBuffer());
|
|
||||||
} catch (IOException e) {
|
|
||||||
ioException = true;
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasQueuedWriteOps() {
|
public boolean hasQueuedWriteOps() {
|
||||||
channel.getSelector().assertOnSelectorThread();
|
channel.getSelector().assertOnSelectorThread();
|
||||||
|
@ -156,18 +146,12 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
@Override
|
@Override
|
||||||
public int read() throws IOException {
|
public int read() throws IOException {
|
||||||
int bytesRead = 0;
|
int bytesRead = 0;
|
||||||
if (ioException) {
|
if (hasIOException()) {
|
||||||
return bytesRead;
|
return bytesRead;
|
||||||
}
|
}
|
||||||
try {
|
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
|
||||||
bytesRead = channel.read(sslDriver.getNetworkReadBuffer());
|
if (bytesRead == 0) {
|
||||||
} catch (IOException e) {
|
return bytesRead;
|
||||||
ioException = true;
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
if (bytesRead < 0) {
|
|
||||||
peerClosed = true;
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sslDriver.read(buffer);
|
sslDriver.read(buffer);
|
||||||
|
@ -183,7 +167,7 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean selectorShouldClose() {
|
public boolean selectorShouldClose() {
|
||||||
return peerClosed || ioException || sslDriver.isClosed();
|
return isPeerClosed() || hasIOException() || sslDriver.isClosed();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -202,14 +186,27 @@ public final class SSLChannelContext implements ChannelContext {
|
||||||
@Override
|
@Override
|
||||||
public void closeFromSelector() throws IOException {
|
public void closeFromSelector() throws IOException {
|
||||||
channel.getSelector().assertOnSelectorThread();
|
channel.getSelector().assertOnSelectorThread();
|
||||||
|
if (channel.isOpen()) {
|
||||||
// Set to true in order to reject new writes before queuing with selector
|
// Set to true in order to reject new writes before queuing with selector
|
||||||
isClosing.set(true);
|
isClosing.set(true);
|
||||||
|
ArrayList<IOException> closingExceptions = new ArrayList<>(2);
|
||||||
|
try {
|
||||||
|
channel.closeFromSelector();
|
||||||
|
} catch (IOException e) {
|
||||||
|
closingExceptions.add(e);
|
||||||
|
}
|
||||||
|
try {
|
||||||
buffer.close();
|
buffer.close();
|
||||||
for (BytesWriteOperation op : queued) {
|
for (BytesWriteOperation op : queued) {
|
||||||
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
|
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
|
||||||
}
|
}
|
||||||
queued.clear();
|
queued.clear();
|
||||||
sslDriver.close();
|
sslDriver.close();
|
||||||
|
} catch (IOException e) {
|
||||||
|
closingExceptions.add(e);
|
||||||
|
}
|
||||||
|
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class CloseNotifyOperation implements WriteOperation {
|
private static class CloseNotifyOperation implements WriteOperation {
|
||||||
|
|
|
@ -13,10 +13,11 @@ import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.util.BigArrays;
|
import org.elasticsearch.common.util.BigArrays;
|
||||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||||
|
import org.elasticsearch.nio.SocketChannelContext;
|
||||||
import org.elasticsearch.nio.AcceptingSelector;
|
import org.elasticsearch.nio.AcceptingSelector;
|
||||||
import org.elasticsearch.nio.ChannelContext;
|
|
||||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||||
import org.elasticsearch.nio.NioSocketChannel;
|
import org.elasticsearch.nio.NioSocketChannel;
|
||||||
|
import org.elasticsearch.nio.ServerChannelContext;
|
||||||
import org.elasticsearch.nio.SocketSelector;
|
import org.elasticsearch.nio.SocketSelector;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TcpTransport;
|
import org.elasticsearch.transport.TcpTransport;
|
||||||
|
@ -36,6 +37,7 @@ import java.nio.channels.SocketChannel;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.function.BiConsumer;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.security.SecurityField.setting;
|
import static org.elasticsearch.xpack.security.SecurityField.setting;
|
||||||
|
@ -125,19 +127,21 @@ public class SecurityNioTransport extends NioTransport {
|
||||||
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||||
};
|
};
|
||||||
|
|
||||||
ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
|
SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
|
||||||
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
|
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
|
||||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||||
SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer);
|
BiConsumer<NioSocketChannel, Exception> exceptionHandler = SecurityNioTransport.this::exceptionCaught;
|
||||||
nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught);
|
SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer);
|
||||||
|
nioChannel.setContext(context);
|
||||||
return nioChannel;
|
return nioChannel;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
|
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
|
||||||
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
|
TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
|
||||||
nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel);
|
ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {});
|
||||||
return nioServerChannel;
|
nioChannel.setContext(context);
|
||||||
|
return nioChannel;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -6,9 +6,8 @@
|
||||||
package org.elasticsearch.xpack.security.transport.nio;
|
package org.elasticsearch.xpack.security.transport.nio;
|
||||||
|
|
||||||
import org.elasticsearch.common.util.BigArrays;
|
import org.elasticsearch.common.util.BigArrays;
|
||||||
import org.elasticsearch.nio.BytesChannelContext;
|
import org.elasticsearch.nio.SocketChannelContext;
|
||||||
import org.elasticsearch.nio.BytesWriteOperation;
|
import org.elasticsearch.nio.BytesWriteOperation;
|
||||||
import org.elasticsearch.nio.ChannelContext;
|
|
||||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||||
import org.elasticsearch.nio.NioSocketChannel;
|
import org.elasticsearch.nio.NioSocketChannel;
|
||||||
import org.elasticsearch.nio.SocketSelector;
|
import org.elasticsearch.nio.SocketSelector;
|
||||||
|
@ -21,6 +20,7 @@ import org.mockito.stubbing.Answer;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.ClosedChannelException;
|
import java.nio.channels.ClosedChannelException;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.function.BiConsumer;
|
import java.util.function.BiConsumer;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class SSLChannelContextTests extends ESTestCase {
|
public class SSLChannelContextTests extends ESTestCase {
|
||||||
|
|
||||||
private ChannelContext.ReadConsumer readConsumer;
|
private SocketChannelContext.ReadConsumer readConsumer;
|
||||||
private NioSocketChannel channel;
|
private NioSocketChannel channel;
|
||||||
private SSLChannelContext context;
|
private SSLChannelContext context;
|
||||||
private InboundChannelBuffer channelBuffer;
|
private InboundChannelBuffer channelBuffer;
|
||||||
|
@ -49,18 +49,15 @@ public class SSLChannelContextTests extends ESTestCase {
|
||||||
@Before
|
@Before
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public void init() {
|
public void init() {
|
||||||
readConsumer = mock(ChannelContext.ReadConsumer.class);
|
readConsumer = mock(SocketChannelContext.ReadConsumer.class);
|
||||||
|
|
||||||
messageLength = randomInt(96) + 20;
|
messageLength = randomInt(96) + 20;
|
||||||
selector = mock(SocketSelector.class);
|
selector = mock(SocketSelector.class);
|
||||||
listener = mock(BiConsumer.class);
|
listener = mock(BiConsumer.class);
|
||||||
channel = mock(NioSocketChannel.class);
|
channel = mock(NioSocketChannel.class);
|
||||||
sslDriver = mock(SSLDriver.class);
|
sslDriver = mock(SSLDriver.class);
|
||||||
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
|
channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||||
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {
|
context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer);
|
||||||
});
|
|
||||||
channelBuffer = new InboundChannelBuffer(pageSupplier);
|
|
||||||
context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer);
|
|
||||||
|
|
||||||
when(channel.getSelector()).thenReturn(selector);
|
when(channel.getSelector()).thenReturn(selector);
|
||||||
when(selector.isOnCurrentThread()).thenReturn(true);
|
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||||
|
@ -145,14 +142,17 @@ public class SSLChannelContextTests extends ESTestCase {
|
||||||
assertTrue(context.selectorShouldClose());
|
assertTrue(context.selectorShouldClose());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testCloseClosesChannelBuffer() throws IOException {
|
public void testCloseClosesChannelBuffer() throws IOException {
|
||||||
Runnable closer = mock(Runnable.class);
|
AtomicInteger closeCount = new AtomicInteger(0);
|
||||||
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
|
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
|
||||||
|
closeCount::incrementAndGet);
|
||||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||||
buffer.ensureCapacity(1);
|
buffer.ensureCapacity(1);
|
||||||
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
|
SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer);
|
||||||
|
when(channel.isOpen()).thenReturn(true);
|
||||||
context.closeFromSelector();
|
context.closeFromSelector();
|
||||||
verify(closer).run();
|
assertEquals(1, closeCount.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWriteOpsClearedOnClose() throws IOException {
|
public void testWriteOpsClearedOnClose() throws IOException {
|
||||||
|
@ -164,6 +164,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
||||||
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
|
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
|
||||||
assertTrue(context.hasQueuedWriteOps());
|
assertTrue(context.hasQueuedWriteOps());
|
||||||
|
|
||||||
|
when(channel.isOpen()).thenReturn(true);
|
||||||
context.closeFromSelector();
|
context.closeFromSelector();
|
||||||
|
|
||||||
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
|
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
|
||||||
|
@ -172,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSSLDriverClosedOnClose() throws IOException {
|
public void testSSLDriverClosedOnClose() throws IOException {
|
||||||
|
when(channel.isOpen()).thenReturn(true);
|
||||||
context.closeFromSelector();
|
context.closeFromSelector();
|
||||||
|
|
||||||
verify(sslDriver).close();
|
verify(sslDriver).close();
|
||||||
|
|
Loading…
Reference in New Issue