Support changes in nio channel contexts (elastic/x-pack-elasticsearch#3609)
This is related to elastic/elasticsearch#elastic/x-pack-elasticsearch#28275. It modifies x-pack to support the changes in channel contexts. Additionally, it simplifies the SSLChannelContext by relying on some common work between it and BytesChannelContext. Original commit: elastic/x-pack-elasticsearch@8a8fcce050
This commit is contained in:
parent
e775e84a7e
commit
685b75da3a
|
@ -5,16 +5,18 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.security.transport.nio;
|
||||
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.nio.BytesWriteOperation;
|
||||
import org.elasticsearch.nio.ChannelContext;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.SocketSelector;
|
||||
import org.elasticsearch.nio.WriteOperation;
|
||||
import org.elasticsearch.nio.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedList;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
|
@ -22,22 +24,20 @@ import java.util.function.BiConsumer;
|
|||
/**
|
||||
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
|
||||
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
|
||||
* before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will
|
||||
* before being passed to the {@link ReadConsumer}. Outbound data will
|
||||
* be encrypted before being flushed to the channel.
|
||||
*/
|
||||
public final class SSLChannelContext implements ChannelContext {
|
||||
public final class SSLChannelContext extends SocketChannelContext {
|
||||
|
||||
private final NioSocketChannel channel;
|
||||
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
|
||||
private final SSLDriver sslDriver;
|
||||
private final ReadConsumer readConsumer;
|
||||
private final InboundChannelBuffer buffer;
|
||||
private final AtomicBoolean isClosing = new AtomicBoolean(false);
|
||||
private boolean peerClosed = false;
|
||||
private boolean ioException = false;
|
||||
|
||||
SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) {
|
||||
this.channel = channel;
|
||||
SSLChannelContext(NioSocketChannel channel, BiConsumer<NioSocketChannel, Exception> exceptionHandler, SSLDriver sslDriver,
|
||||
ReadConsumer readConsumer, InboundChannelBuffer buffer) {
|
||||
super(channel, exceptionHandler);
|
||||
this.sslDriver = sslDriver;
|
||||
this.readConsumer = readConsumer;
|
||||
this.buffer = buffer;
|
||||
|
@ -64,7 +64,6 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
return;
|
||||
}
|
||||
|
||||
// TODO: Eval if we will allow writes from sendMessage
|
||||
selector.queueWriteInChannelBuffer(writeOperation);
|
||||
}
|
||||
|
||||
|
@ -80,14 +79,14 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
|
||||
@Override
|
||||
public void flushChannel() throws IOException {
|
||||
if (ioException) {
|
||||
if (hasIOException()) {
|
||||
return;
|
||||
}
|
||||
// If there is currently data in the outbound write buffer, flush the buffer.
|
||||
if (sslDriver.hasFlushPending()) {
|
||||
internalFlush();
|
||||
// If the data is not completely flushed, exit. We cannot produce new write data until the
|
||||
// existing data has been fully flushed.
|
||||
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||
if (sslDriver.hasFlushPending()) {
|
||||
return;
|
||||
}
|
||||
|
@ -113,7 +112,7 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
}
|
||||
currentOperation.incrementIndex(bytesEncrypted);
|
||||
// Flush the write buffer to the channel
|
||||
internalFlush();
|
||||
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||
} catch (IOException e) {
|
||||
queued.removeFirst();
|
||||
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
|
||||
|
@ -128,21 +127,12 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
sslDriver.nonApplicationWrite();
|
||||
// If non-application writes were produced, flush the outbound write buffer.
|
||||
if (sslDriver.hasFlushPending()) {
|
||||
internalFlush();
|
||||
flushToChannel(sslDriver.getNetworkWriteBuffer());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private int internalFlush() throws IOException {
|
||||
try {
|
||||
return channel.write(sslDriver.getNetworkWriteBuffer());
|
||||
} catch (IOException e) {
|
||||
ioException = true;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasQueuedWriteOps() {
|
||||
channel.getSelector().assertOnSelectorThread();
|
||||
|
@ -156,18 +146,12 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
@Override
|
||||
public int read() throws IOException {
|
||||
int bytesRead = 0;
|
||||
if (ioException) {
|
||||
if (hasIOException()) {
|
||||
return bytesRead;
|
||||
}
|
||||
try {
|
||||
bytesRead = channel.read(sslDriver.getNetworkReadBuffer());
|
||||
} catch (IOException e) {
|
||||
ioException = true;
|
||||
throw e;
|
||||
}
|
||||
if (bytesRead < 0) {
|
||||
peerClosed = true;
|
||||
return 0;
|
||||
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
|
||||
if (bytesRead == 0) {
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
sslDriver.read(buffer);
|
||||
|
@ -183,7 +167,7 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
|
||||
@Override
|
||||
public boolean selectorShouldClose() {
|
||||
return peerClosed || ioException || sslDriver.isClosed();
|
||||
return isPeerClosed() || hasIOException() || sslDriver.isClosed();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -202,14 +186,27 @@ public final class SSLChannelContext implements ChannelContext {
|
|||
@Override
|
||||
public void closeFromSelector() throws IOException {
|
||||
channel.getSelector().assertOnSelectorThread();
|
||||
// Set to true in order to reject new writes before queuing with selector
|
||||
isClosing.set(true);
|
||||
buffer.close();
|
||||
for (BytesWriteOperation op : queued) {
|
||||
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
|
||||
if (channel.isOpen()) {
|
||||
// Set to true in order to reject new writes before queuing with selector
|
||||
isClosing.set(true);
|
||||
ArrayList<IOException> closingExceptions = new ArrayList<>(2);
|
||||
try {
|
||||
channel.closeFromSelector();
|
||||
} catch (IOException e) {
|
||||
closingExceptions.add(e);
|
||||
}
|
||||
try {
|
||||
buffer.close();
|
||||
for (BytesWriteOperation op : queued) {
|
||||
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
|
||||
}
|
||||
queued.clear();
|
||||
sslDriver.close();
|
||||
} catch (IOException e) {
|
||||
closingExceptions.add(e);
|
||||
}
|
||||
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
|
||||
}
|
||||
queued.clear();
|
||||
sslDriver.close();
|
||||
}
|
||||
|
||||
private static class CloseNotifyOperation implements WriteOperation {
|
||||
|
|
|
@ -13,10 +13,11 @@ import org.elasticsearch.common.settings.Settings;
|
|||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.nio.AcceptingSelector;
|
||||
import org.elasticsearch.nio.ChannelContext;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.ServerChannelContext;
|
||||
import org.elasticsearch.nio.SocketSelector;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TcpTransport;
|
||||
|
@ -36,6 +37,7 @@ import java.nio.channels.SocketChannel;
|
|||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.security.SecurityField.setting;
|
||||
|
@ -125,19 +127,21 @@ public class SecurityNioTransport extends NioTransport {
|
|||
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
|
||||
};
|
||||
|
||||
ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
|
||||
SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
|
||||
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||
SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer);
|
||||
nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught);
|
||||
BiConsumer<NioSocketChannel, Exception> exceptionHandler = SecurityNioTransport.this::exceptionCaught;
|
||||
SSLChannelContext context = new SSLChannelContext(nioChannel, exceptionHandler, sslDriver, nioReadConsumer, buffer);
|
||||
nioChannel.setContext(context);
|
||||
return nioChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
|
||||
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
|
||||
nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel);
|
||||
return nioServerChannel;
|
||||
TcpNioServerSocketChannel nioChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
|
||||
ServerChannelContext context = new ServerChannelContext(nioChannel, SecurityNioTransport.this::acceptChannel, (c, e) -> {});
|
||||
nioChannel.setContext(context);
|
||||
return nioChannel;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,9 +6,8 @@
|
|||
package org.elasticsearch.xpack.security.transport.nio;
|
||||
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.nio.BytesChannelContext;
|
||||
import org.elasticsearch.nio.SocketChannelContext;
|
||||
import org.elasticsearch.nio.BytesWriteOperation;
|
||||
import org.elasticsearch.nio.ChannelContext;
|
||||
import org.elasticsearch.nio.InboundChannelBuffer;
|
||||
import org.elasticsearch.nio.NioSocketChannel;
|
||||
import org.elasticsearch.nio.SocketSelector;
|
||||
|
@ -21,6 +20,7 @@ import org.mockito.stubbing.Answer;
|
|||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
|
@ -35,7 +35,7 @@ import static org.mockito.Mockito.when;
|
|||
|
||||
public class SSLChannelContextTests extends ESTestCase {
|
||||
|
||||
private ChannelContext.ReadConsumer readConsumer;
|
||||
private SocketChannelContext.ReadConsumer readConsumer;
|
||||
private NioSocketChannel channel;
|
||||
private SSLChannelContext context;
|
||||
private InboundChannelBuffer channelBuffer;
|
||||
|
@ -49,18 +49,15 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
@Before
|
||||
@SuppressWarnings("unchecked")
|
||||
public void init() {
|
||||
readConsumer = mock(ChannelContext.ReadConsumer.class);
|
||||
readConsumer = mock(SocketChannelContext.ReadConsumer.class);
|
||||
|
||||
messageLength = randomInt(96) + 20;
|
||||
selector = mock(SocketSelector.class);
|
||||
listener = mock(BiConsumer.class);
|
||||
channel = mock(NioSocketChannel.class);
|
||||
sslDriver = mock(SSLDriver.class);
|
||||
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
|
||||
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {
|
||||
});
|
||||
channelBuffer = new InboundChannelBuffer(pageSupplier);
|
||||
context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer);
|
||||
channelBuffer = InboundChannelBuffer.allocatingInstance();
|
||||
context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, channelBuffer);
|
||||
|
||||
when(channel.getSelector()).thenReturn(selector);
|
||||
when(selector.isOnCurrentThread()).thenReturn(true);
|
||||
|
@ -145,14 +142,17 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
assertTrue(context.selectorShouldClose());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testCloseClosesChannelBuffer() throws IOException {
|
||||
Runnable closer = mock(Runnable.class);
|
||||
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
|
||||
AtomicInteger closeCount = new AtomicInteger(0);
|
||||
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14),
|
||||
closeCount::incrementAndGet);
|
||||
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
|
||||
buffer.ensureCapacity(1);
|
||||
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
|
||||
SSLChannelContext context = new SSLChannelContext(channel, mock(BiConsumer.class), sslDriver, readConsumer, buffer);
|
||||
when(channel.isOpen()).thenReturn(true);
|
||||
context.closeFromSelector();
|
||||
verify(closer).run();
|
||||
assertEquals(1, closeCount.get());
|
||||
}
|
||||
|
||||
public void testWriteOpsClearedOnClose() throws IOException {
|
||||
|
@ -164,6 +164,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
|
||||
assertTrue(context.hasQueuedWriteOps());
|
||||
|
||||
when(channel.isOpen()).thenReturn(true);
|
||||
context.closeFromSelector();
|
||||
|
||||
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
|
||||
|
@ -172,6 +173,7 @@ public class SSLChannelContextTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testSSLDriverClosedOnClose() throws IOException {
|
||||
when(channel.isOpen()).thenReturn(true);
|
||||
context.closeFromSelector();
|
||||
|
||||
verify(sslDriver).close();
|
||||
|
|
Loading…
Reference in New Issue