Add TLS/SSL enabled SecurityNioTransport (elastic/x-pack-elasticsearch#3519)

This is related to elastic/x-pack-elasticsearch#3246. This commit adds a SSL/TLS layer to the nio
work implemented in the SSLChannelContext and SSLDriver classes.
This work is used to build up a SecurityNioTransport implementation.
This transport does yet offer feature parity with our normal security
transport. It mainly offers SSL/TLS security.

Original commit: elastic/x-pack-elasticsearch@d0e0484418
This commit is contained in:
Tim Brooks 2018-01-17 09:44:31 -07:00 committed by GitHub
parent a4fad02d9a
commit dda3a8dee0
7 changed files with 1871 additions and 0 deletions

View File

@ -29,11 +29,15 @@ dependencyLicenses {
mapping from: /bc.*/, to: 'bouncycastle'
mapping from: /owasp-java-html-sanitizer.*/, to: 'owasp-java-html-sanitizer'
mapping from: /transport-netty.*/, to: 'elasticsearch'
mapping from: /transport-nio.*/, to: 'elasticsearch'
mapping from: /elasticsearch-nio.*/, to: 'elasticsearch'
mapping from: /elasticsearch-rest-client.*/, to: 'elasticsearch'
mapping from: /http.*/, to: 'httpclient' // pulled in by rest client
mapping from: /commons-.*/, to: 'commons' // pulled in by rest client
ignoreSha 'elasticsearch-rest-client'
ignoreSha 'transport-netty4'
ignoreSha 'transport-nio'
ignoreSha 'elasticsearch-nio'
ignoreSha 'elasticsearch-rest-client-sniffer'
ignoreSha 'x-pack-core'
}
@ -59,6 +63,7 @@ dependencies {
// security deps
compile project(path: ':modules:transport-netty4', configuration: 'runtime')
compile project(path: ':plugins:transport-nio', configuration: 'runtime')
compile 'com.unboundid:unboundid-ldapsdk:3.2.0'
compile 'org.bouncycastle:bcprov-jdk15on:1.58'
compile 'org.bouncycastle:bcpkix-jdk15on:1.58'

View File

@ -0,0 +1,234 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
/**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
* before being passed to the {@link org.elasticsearch.nio.ChannelContext.ReadConsumer}. Outbound data will
* be encrypted before being flushed to the channel.
*/
public final class SSLChannelContext implements ChannelContext {
private final NioSocketChannel channel;
private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
private final SSLDriver sslDriver;
private final ReadConsumer readConsumer;
private final InboundChannelBuffer buffer;
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private boolean peerClosed = false;
private boolean ioException = false;
SSLChannelContext(NioSocketChannel channel, SSLDriver sslDriver, ReadConsumer readConsumer, InboundChannelBuffer buffer) {
this.channel = channel;
this.sslDriver = sslDriver;
this.readConsumer = readConsumer;
this.buffer = buffer;
}
@Override
public void channelRegistered() throws IOException {
sslDriver.init();
}
@Override
public void sendMessage(ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
if (isClosing.get()) {
listener.accept(null, new ClosedChannelException());
return;
}
BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
// If this message is being sent from another thread, we queue the write to be handled by the
// network thread
selector.queueWrite(writeOperation);
return;
}
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation);
}
@Override
public void queueWriteOperation(WriteOperation writeOperation) {
channel.getSelector().assertOnSelectorThread();
if (writeOperation instanceof CloseNotifyOperation) {
sslDriver.initiateClose();
} else {
queued.add((BytesWriteOperation) writeOperation);
}
}
@Override
public void flushChannel() throws IOException {
if (ioException) {
return;
}
// If there is currently data in the outbound write buffer, flush the buffer.
if (sslDriver.hasFlushPending()) {
internalFlush();
// If the data is not completely flushed, exit. We cannot produce new write data until the
// existing data has been fully flushed.
if (sslDriver.hasFlushPending()) {
return;
}
}
// If the driver is ready for application writes, we can attempt to proceed with any queued writes.
if (sslDriver.readyForApplicationWrites()) {
BytesWriteOperation currentOperation = queued.peekFirst();
while (sslDriver.hasFlushPending() == false && currentOperation != null) {
// If the current operation has been fully consumed (encrypted) we now know that it has been
// sent (as we only get to this point if the write buffer has been fully flushed).
if (currentOperation.isFullyFlushed()) {
queued.removeFirst();
channel.getSelector().executeListener(currentOperation.getListener(), null);
currentOperation = queued.peekFirst();
} else {
try {
// Attempt to encrypt application write data. The encrypted data ends up in the
// outbound write buffer.
int bytesEncrypted = sslDriver.applicationWrite(currentOperation.getBuffersToWrite());
if (bytesEncrypted == 0) {
break;
}
currentOperation.incrementIndex(bytesEncrypted);
// Flush the write buffer to the channel
internalFlush();
} catch (IOException e) {
queued.removeFirst();
channel.getSelector().executeFailedListener(currentOperation.getListener(), e);
throw e;
}
}
}
} else {
// We are not ready for application writes, check if the driver has non-application writes. We
// only want to continue producing new writes if the outbound write buffer is fully flushed.
while (sslDriver.hasFlushPending() == false && sslDriver.needsNonApplicationWrite()) {
sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer.
if (sslDriver.hasFlushPending()) {
internalFlush();
}
}
}
}
private int internalFlush() throws IOException {
try {
return channel.write(sslDriver.getNetworkWriteBuffer());
} catch (IOException e) {
ioException = true;
throw e;
}
}
@Override
public boolean hasQueuedWriteOps() {
channel.getSelector().assertOnSelectorThread();
if (sslDriver.readyForApplicationWrites()) {
return sslDriver.hasFlushPending() || queued.isEmpty() == false;
} else {
return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite();
}
}
@Override
public int read() throws IOException {
int bytesRead = 0;
if (ioException) {
return bytesRead;
}
try {
bytesRead = channel.read(sslDriver.getNetworkReadBuffer());
} catch (IOException e) {
ioException = true;
throw e;
}
if (bytesRead < 0) {
peerClosed = true;
return 0;
}
sslDriver.read(buffer);
int bytesConsumed = Integer.MAX_VALUE;
while (bytesConsumed > 0 && buffer.getIndex() > 0) {
bytesConsumed = readConsumer.consumeReads(buffer);
buffer.release(bytesConsumed);
}
return bytesRead;
}
@Override
public boolean selectorShouldClose() {
return peerClosed || ioException || sslDriver.isClosed();
}
@Override
public void closeChannel() {
if (isClosing.compareAndSet(false, true)) {
WriteOperation writeOperation = new CloseNotifyOperation(channel);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
return;
}
selector.queueWriteInChannelBuffer(writeOperation);
}
}
@Override
public void closeFromSelector() throws IOException {
channel.getSelector().assertOnSelectorThread();
// Set to true in order to reject new writes before queuing with selector
isClosing.set(true);
buffer.close();
for (BytesWriteOperation op : queued) {
channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
}
queued.clear();
sslDriver.close();
}
private static class CloseNotifyOperation implements WriteOperation {
private static final BiConsumer<Void, Throwable> LISTENER = (v, t) -> {};
private final NioSocketChannel channel;
private CloseNotifyOperation(NioSocketChannel channel) {
this.channel = channel;
}
@Override
public BiConsumer<Void, Throwable> getListener() {
return LISTENER;
}
@Override
public NioSocketChannel getChannel() {
return channel;
}
}
}

View File

@ -0,0 +1,567 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.nio.ByteBuffer;
import java.util.ArrayList;
/**
* SSLDriver is a class that wraps the {@link SSLEngine} and attempts to simplify the API. The basic usage is
* to create an SSLDriver class and call {@link #init()}. This initiates the SSL/TLS handshaking process.
*
* When the SSLDriver is handshaking or closing, reads and writes will be consumed/produced internally to
* advance the handshake or close process. Alternatively, when the SSLDriver is in application mode, it will
* decrypt data off the wire to be consumed by the application and will encrypt data provided by the
* application to be written to the wire.
*
* Handling reads from a channel with this class is very simple. When data has been read, call
* {@link #read(InboundChannelBuffer)}. If the data is application data, it will be decrypted and placed into
* the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
* or handshake process.
*
* Producing writes for a channel is more complicated. If there is existing data in the outbound write buffer
* as indicated by {@link #hasFlushPending()}, that data must be written to the channel before more outbound
* data can be produced. If no flushes are pending, {@link #needsNonApplicationWrite()} can be called to
* determine if this driver needs to produce more data to advance the handshake or close process. If that
* method returns true, {@link #nonApplicationWrite()} should be called (and the data produced then flushed
* to the channel) until no further non-application writes are needed.
*
* If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine
* if the driver is ready to consume application data. (Note: It is possible that
* {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the
* driver is waiting on non-application data from the peer.) If the driver indicates it is ready for
* application writes, {@link #applicationWrite(ByteBuffer[])} can be called. This method will encrypt
* application data and place it in the write buffer for flushing to a channel.
*
* If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the
* driver will start producing non-application writes related to notifying the peer connection that this
* connection is closing. When {@link #isClosed()} returns true, this SSL connection is closed and the
* channel should be closed.
*/
public class SSLDriver implements AutoCloseable {
private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0];
private final SSLEngine engine;
private final boolean isClientMode;
// This should only be accessed by the network thread associated with this channel, so nothing needs to
// be volatile.
private Mode currentMode = new HandshakeMode();
private ByteBuffer networkWriteBuffer;
private ByteBuffer networkReadBuffer;
public SSLDriver(SSLEngine engine, boolean isClientMode) {
this.engine = engine;
this.isClientMode = isClientMode;
SSLSession session = engine.getSession();
this.networkReadBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer = ByteBuffer.allocate(session.getPacketBufferSize());
this.networkWriteBuffer.position(this.networkWriteBuffer.limit());
}
public void init() throws SSLException {
engine.setUseClientMode(isClientMode);
if (currentMode.isHandshake()) {
((HandshakeMode) currentMode).startHandshake();
} else {
throw new AssertionError("Attempted to init outside from non-handshaking mode: " + currentMode.modeName());
}
}
public boolean hasFlushPending() {
return networkWriteBuffer.hasRemaining();
}
public boolean isHandshaking() {
return currentMode.isHandshake();
}
public ByteBuffer getNetworkWriteBuffer() {
return networkWriteBuffer;
}
public ByteBuffer getNetworkReadBuffer() {
return networkReadBuffer;
}
public void read(InboundChannelBuffer buffer) throws SSLException {
currentMode.read(buffer);
}
public boolean readyForApplicationWrites() {
return currentMode.isApplication();
}
public boolean needsNonApplicationWrite() {
return currentMode.needsNonApplicationWrite();
}
public int applicationWrite(ByteBuffer[] buffers) throws SSLException {
assert readyForApplicationWrites() : "Should not be called if driver is not ready for application writes";
return currentMode.write(buffers);
}
public void nonApplicationWrite() throws SSLException {
assert currentMode.isApplication() == false : "Should not be called if driver is in application mode";
if (currentMode.isApplication() == false) {
currentMode.write(EMPTY_BUFFER_ARRAY);
} else {
throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName());
}
}
public void initiateClose() {
closingInternal();
}
public boolean isClosed() {
return currentMode.isClose() && ((CloseMode) currentMode).isCloseDone();
}
@Override
public void close() throws SSLException {
ArrayList<SSLException> closingExceptions = new ArrayList<>(2);
closingInternal();
CloseMode closeMode = (CloseMode) this.currentMode;
if (closeMode.needToSendClose) {
closingExceptions.add(new SSLException("Closed engine without completely sending the close alert message."));
engine.closeOutbound();
}
if (closeMode.needToReceiveClose) {
closingExceptions.add(new SSLException("Closed engine without receiving the close alert message."));
closeMode.closeInboundAndSwallowPeerDidNotCloseException();
}
ExceptionsHelper.rethrowAndSuppress(closingExceptions);
}
private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException {
while (true) {
SSLEngineResult result = engine.unwrap(networkReadBuffer, buffer.sliceBuffersFrom(buffer.getIndex()));
buffer.incrementIndex(result.bytesProduced());
switch (result.getStatus()) {
case OK:
networkReadBuffer.compact();
return result;
case BUFFER_UNDERFLOW:
// There is not enough space in the network buffer for an entire SSL packet. Compact the
// current data and expand the buffer if necessary.
int currentCapacity = networkReadBuffer.capacity();
ensureNetworkReadBufferSize();
if (currentCapacity == networkReadBuffer.capacity()) {
networkReadBuffer.compact();
}
return result;
case BUFFER_OVERFLOW:
// There is not enough space in the application buffer for the decrypted message. Expand
// the application buffer to ensure that it has enough space.
ensureApplicationBufferSize(buffer);
break;
case CLOSED:
assert engine.isInboundDone() : "We received close_notify so read should be done";
closingInternal();
return result;
default:
throw new IllegalStateException("Unexpected UNWRAP result: " + result.getStatus());
}
}
}
private SSLEngineResult wrap(ByteBuffer[] buffers) throws SSLException {
assert hasFlushPending() == false : "Should never called with pending writes";
networkWriteBuffer.clear();
while (true) {
SSLEngineResult result;
try {
if (buffers.length == 1) {
result = engine.wrap(buffers[0], networkWriteBuffer);
} else {
result = engine.wrap(buffers, networkWriteBuffer);
}
} catch (SSLException e) {
networkWriteBuffer.position(networkWriteBuffer.limit());
throw e;
}
switch (result.getStatus()) {
case OK:
networkWriteBuffer.flip();
return result;
case BUFFER_UNDERFLOW:
throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP");
case BUFFER_OVERFLOW:
// There is not enough space in the network buffer for an entire SSL packet. Expand the
// buffer if it's smaller than the current session packet size. Otherwise return and wait
// for existing data to be flushed.
int currentCapacity = networkWriteBuffer.capacity();
ensureNetworkWriteBufferSize();
if (currentCapacity == networkWriteBuffer.capacity()) {
return result;
}
break;
case CLOSED:
if (result.bytesProduced() > 0) {
networkWriteBuffer.flip();
} else {
assert false : "WRAP during close processing should produce close message.";
}
return result;
default:
throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus());
}
}
}
private boolean checkRenegotiation(SSLEngineResult.HandshakeStatus newStatus) {
if (isHandshaking() == false && newStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
&& newStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
// TODO: Iron out the specifics of renegotiation
throw new IllegalStateException("We do not support renegotiation");
}
return false;
}
private void closingInternal() {
// This check prevents us from attempting to send close_notify twice
if (currentMode.isClose() == false) {
currentMode = new CloseMode(currentMode.isHandshake());
}
}
private void ensureApplicationBufferSize(InboundChannelBuffer applicationBuffer) {
int applicationBufferSize = engine.getSession().getApplicationBufferSize();
if (applicationBuffer.getRemaining() < applicationBufferSize) {
applicationBuffer.ensureCapacity(applicationBuffer.getIndex() + engine.getSession().getApplicationBufferSize());
}
}
private void ensureNetworkWriteBufferSize() {
networkWriteBuffer = ensureNetBufferSize(networkWriteBuffer);
}
private void ensureNetworkReadBufferSize() {
networkReadBuffer = ensureNetBufferSize(networkReadBuffer);
}
private ByteBuffer ensureNetBufferSize(ByteBuffer current) {
int networkPacketSize = engine.getSession().getPacketBufferSize();
if (current.capacity() < networkPacketSize) {
ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize);
current.flip();
newBuffer.put(current);
return newBuffer;
} else {
return current;
}
}
// There are three potential modes for the driver to be in - HANDSHAKE, APPLICATION, or CLOSE. HANDSHAKE
// is the initial mode. During this mode data that is read and written will be related to the TLS
// handshake process. Application related data cannot be encrypted until the handshake is complete. From
// HANDSHAKE mode the driver can transition to APPLICATION (if the handshake is successful) or CLOSE (if
// an error occurs or we initiate a close). In APPLICATION mode data read from the channel will be
// decrypted and placed into the buffer passed as an argument to the read call. Additionally, application
// writes will be accepted and encrypted into the outbound write buffer. APPLICATION mode will proceed
// until we receive a request for renegotiation (currently unsupported) or the CLOSE mode begins. CLOSE
// mode can begin if we receive a CLOSE_NOTIFY message from the peer or if initiateClose is called. In
// CLOSE mode we attempt to both send and receive an SSL CLOSE_NOTIFY message. The exception to this is
// when we enter CLOSE mode from HANDSHAKE mode. In this scenario we only need to send the alert to the
// peer and then close the channel. Some SSL/TLS implementations do not properly adhere to the full
// two-direction close_notify process. Additionally, in newer TLS specifications it is not required to
// wait to receive close_notify. However, we will make our best attempt to both send and receive as it is
// expected by the java SSLEngine (it throws an exception if close_notify has not been received when
// inbound is closed).
private interface Mode {
void read(InboundChannelBuffer buffer) throws SSLException;
int write(ByteBuffer[] buffers) throws SSLException;
boolean needsNonApplicationWrite();
boolean isHandshake();
boolean isApplication();
boolean isClose();
String modeName();
}
private class HandshakeMode implements Mode {
private SSLEngineResult.HandshakeStatus handshakeStatus;
private void startHandshake() throws SSLException {
engine.beginHandshake();
handshakeStatus = engine.getHandshakeStatus();
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
try {
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
}
}
}
private void handshake() throws SSLException {
boolean continueHandshaking = true;
while (continueHandshaking) {
switch (handshakeStatus) {
case NEED_UNWRAP:
// We UNWRAP as much as possible immediately after a read. Do not need to do it here.
continueHandshaking = false;
break;
case NEED_WRAP:
if (hasFlushPending() == false) {
handshakeStatus = wrap(EMPTY_BUFFER_ARRAY).getHandshakeStatus();
}
continueHandshaking = false;
break;
case NEED_TASK:
runTasks();
handshakeStatus = engine.getHandshakeStatus();
break;
case NOT_HANDSHAKING:
maybeFinishHandshake();
continueHandshaking = false;
break;
case FINISHED:
maybeFinishHandshake();
continueHandshaking = false;
break;
}
}
}
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
try {
SSLEngineResult result = unwrap(buffer);
handshakeStatus = result.getHandshakeStatus();
continueUnwrap = result.bytesConsumed() > 0;
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
}
}
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
try {
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
}
return 0;
}
@Override
public boolean needsNonApplicationWrite() {
return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
|| handshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
|| handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED;
}
@Override
public boolean isHandshake() {
return true;
}
@Override
public boolean isApplication() {
return false;
}
@Override
public boolean isClose() {
return false;
}
@Override
public String modeName() {
return "HANDSHAKE";
}
private void runTasks() {
Runnable delegatedTask;
while ((delegatedTask = engine.getDelegatedTask()) != null) {
delegatedTask.run();
}
}
private void maybeFinishHandshake() {
// We only acknowledge that we are done handshaking if there are no bytes that need to be written
if (hasFlushPending() == false) {
if (currentMode.isHandshake()) {
currentMode = new ApplicationMode();
} else {
String message = "Attempted to transition to application mode from non-handshaking mode: " + currentMode;
throw new AssertionError(message);
}
}
}
}
private class ApplicationMode implements Mode {
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
ensureApplicationBufferSize(buffer);
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer);
boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
&& checkRenegotiation(result.getHandshakeStatus());
continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false;
}
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
SSLEngineResult result = wrap(buffers);
checkRenegotiation(result.getHandshakeStatus());
return result.bytesConsumed();
}
@Override
public boolean needsNonApplicationWrite() {
return false;
}
@Override
public boolean isHandshake() {
return false;
}
@Override
public boolean isApplication() {
return true;
}
@Override
public boolean isClose() {
return false;
}
@Override
public String modeName() {
return "APPLICATION";
}
}
private class CloseMode implements Mode {
private boolean needToSendClose = true;
private boolean needToReceiveClose = true;
private CloseMode(boolean isHandshaking) {
if (isHandshaking && engine.isInboundDone() == false) {
// If we attempt to close during a handshake either we are sending an alert and inbound
// should already be closed or we are sending a close_notify. If we send a close_notify
// the peer will send an handshake error alert. If we attempt to receive the handshake alert,
// the engine will throw an IllegalStateException as it is not in a proper state to receive
// handshake message. Closing inbound immediately after close_notify is the cleanest option.
needToReceiveClose = false;
} else if (engine.isInboundDone()) {
needToReceiveClose = false;
}
if (engine.isOutboundDone()) {
needToSendClose = false;
} else {
engine.closeOutbound();
}
}
@Override
public void read(InboundChannelBuffer buffer) throws SSLException {
ensureApplicationBufferSize(buffer);
boolean continueUnwrap = true;
while (continueUnwrap && networkReadBuffer.position() > 0) {
networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer);
continueUnwrap = result.bytesProduced() > 0;
}
if (engine.isInboundDone()) {
needToReceiveClose = false;
}
}
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
if (hasFlushPending() == false && engine.isOutboundDone()) {
needToSendClose = false;
// Close inbound if it is still open and we have decided not to wait for response.
if (needToReceiveClose == false && engine.isInboundDone() == false) {
closeInboundAndSwallowPeerDidNotCloseException();
}
} else {
wrap(EMPTY_BUFFER_ARRAY);
assert hasFlushPending() : "Should have produced close message";
}
return 0;
}
@Override
public boolean needsNonApplicationWrite() {
return needToSendClose;
}
@Override
public boolean isHandshake() {
return false;
}
@Override
public boolean isApplication() {
return false;
}
@Override
public boolean isClose() {
return true;
}
@Override
public String modeName() {
return "CLOSE";
}
private boolean isCloseDone() {
return needToSendClose == false && needToReceiveClose == false;
}
private void closeInboundAndSwallowPeerDidNotCloseException() throws SSLException {
try {
engine.closeInbound();
} catch (SSLException e) {
if (e.getMessage().startsWith("Inbound closed before receiving peer's close_notify") == false) {
throw e;
}
}
}
}
}

View File

@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.NioTransport;
import org.elasticsearch.transport.nio.TcpNioServerSocketChannel;
import org.elasticsearch.transport.nio.TcpNioSocketChannel;
import org.elasticsearch.xpack.XPackSettings;
import org.elasticsearch.xpack.security.transport.netty4.SecurityNetty4Transport;
import org.elasticsearch.xpack.ssl.SSLConfiguration;
import org.elasticsearch.xpack.ssl.SSLService;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.security.SecurityField.setting;
/**
* This transport provides a transport based on nio that is secured by SSL/TLS. SSL/TLS is a communications
* protocol that allows two channels to go through a handshake process prior to application data being
* exchanged. The handshake process enables the channels to exchange parameters that will allow them to
* encrypt the application data they exchange.
*
* The specific SSL/TLS parameters and configurations are setup in the {@link SSLService} class. The actual
* implementation of the SSL/TLS layer is in the {@link SSLChannelContext} and {@link SSLDriver} classes.
*/
public class SecurityNioTransport extends NioTransport {
private final SSLConfiguration sslConfiguration;
private final SSLService sslService;
private final Map<String, SSLConfiguration> profileConfiguration;
private final boolean sslEnabled;
SecurityNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService, SSLService sslService) {
super(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService);
this.sslService = sslService;
this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings);
final Settings transportSSLSettings = settings.getByPrefix(setting("transport.ssl."));
if (sslEnabled) {
this.sslConfiguration = sslService.sslConfiguration(transportSSLSettings, Settings.EMPTY);
Map<String, Settings> profileSettingsMap = settings.getGroups("transport.profiles.", true);
Map<String, SSLConfiguration> profileConfiguration = new HashMap<>(profileSettingsMap.size() + 1);
for (Map.Entry<String, Settings> entry : profileSettingsMap.entrySet()) {
Settings profileSettings = entry.getValue();
final Settings profileSslSettings = SecurityNetty4Transport.profileSslSettings(profileSettings);
SSLConfiguration configuration = sslService.sslConfiguration(profileSslSettings, transportSSLSettings);
profileConfiguration.put(entry.getKey(), configuration);
}
if (profileConfiguration.containsKey(TcpTransport.DEFAULT_PROFILE) == false) {
profileConfiguration.put(TcpTransport.DEFAULT_PROFILE, sslConfiguration);
}
this.profileConfiguration = Collections.unmodifiableMap(profileConfiguration);
} else {
throw new IllegalArgumentException("Currently only support SSL enabled.");
}
}
@Override
protected TcpChannelFactory channelFactory(ProfileSettings profileSettings, boolean isClient) {
return new SecurityTcpChannelFactory(profileSettings, isClient);
}
@Override
protected void acceptChannel(NioSocketChannel channel) {
super.acceptChannel(channel);
}
@Override
protected void exceptionCaught(NioSocketChannel channel, Exception exception) {
super.exceptionCaught(channel, exception);
}
private class SecurityTcpChannelFactory extends TcpChannelFactory {
private final String profileName;
private final boolean isClient;
private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
this.profileName = profileSettings.profileName;
this.isClient = isClient;
}
@Override
public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException {
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLEngine sslEngine = sslService.createSSLEngine(profileConfiguration.getOrDefault(profileName, defaultConfig), null, -1);
SSLDriver sslDriver = new SSLDriver(sslEngine, isClient);
TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
SSLChannelContext context = new SSLChannelContext(nioChannel, sslDriver, nioReadConsumer, buffer);
nioChannel.setContexts(context, SecurityNioTransport.this::exceptionCaught);
return nioChannel;
}
@Override
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
nioServerChannel.setAcceptContext(SecurityNioTransport.this::acceptChannel);
return nioServerChannel;
}
}
}

View File

@ -0,0 +1,447 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteOperation;
import org.elasticsearch.nio.ChannelContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class SSLChannelContextTests extends ESTestCase {
private ChannelContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private SSLChannelContext context;
private InboundChannelBuffer channelBuffer;
private SocketSelector selector;
private BiConsumer<Void, Throwable> listener;
private SSLDriver sslDriver;
private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14);
private int messageLength;
@Before
@SuppressWarnings("unchecked")
public void init() {
readConsumer = mock(ChannelContext.ReadConsumer.class);
messageLength = randomInt(96) + 20;
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
sslDriver = mock(SSLDriver.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {
});
channelBuffer = new InboundChannelBuffer(pageSupplier);
context = new SSLChannelContext(channel, sslDriver, readConsumer, channelBuffer);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer);
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
assertEquals(messageLength, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
}
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
assertEquals(bytes.length, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(same(readBuffer))).thenReturn(bytes.length);
doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(0);
assertEquals(messageLength, context.read());
assertEquals(bytes.length, channelBuffer.getIndex());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
assertEquals(messageLength, context.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any(ByteBuffer.class))).thenThrow(ioException);
IOException ex = expectThrows(IOException.class, () -> context.read());
assertSame(ioException, ex);
}
public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer.class))).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.read());
assertTrue(context.selectorShouldClose());
}
public void testReadLessThanZeroMeansReadyForClose() throws IOException {
when(channel.read(any(ByteBuffer.class))).thenReturn(-1);
assertEquals(0, context.read());
assertTrue(context.selectorShouldClose());
}
public void testCloseClosesChannelBuffer() throws IOException {
Runnable closer = mock(Runnable.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
buffer.ensureCapacity(1);
BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
context.closeFromSelector();
verify(closer).run();
}
public void testWriteOpsClearedOnClose() throws IOException {
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps());
context.closeFromSelector();
verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
assertFalse(context.hasQueuedWriteOps());
}
public void testSSLDriverClosedOnClose() throws IOException {
context.closeFromSelector();
verify(sslDriver).close();
}
public void testWriteFailsIfClosing() {
context.closeChannel();
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
}
public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
when(selector.isOnCurrentThread()).thenReturn(false);
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
public void testSendMessageFromSameThreadIsQueuedInChannel() {
ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
context.sendMessage(buffers, listener);
verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
BytesWriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
}
public void testWriteIsQueuedInChannel() {
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
assertFalse(context.hasQueuedWriteOps());
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
assertTrue(context.hasQueuedWriteOps());
}
public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
assertFalse(context.hasQueuedWriteOps());
}
public void testPendingFlushMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(randomBoolean());
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(false);
assertTrue(context.hasQueuedWriteOps());
}
public void testNeedsNonAppWritesMeansWriteInterested() {
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(sslDriver.hasFlushPending()).thenReturn(false);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
assertTrue(context.hasQueuedWriteOps());
}
public void testNotWritesInterestInAppMode() {
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.hasFlushPending()).thenReturn(false);
assertFalse(context.hasQueuedWriteOps());
verify(sslDriver, times(0)).needsNonApplicationWrite();
}
public void testFirstFlushMustFinishForWriteToContinue() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(true, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
context.flushChannel();
verify(sslDriver, times(0)).nonApplicationWrite();
}
public void testNonAppWrites() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, false, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
context.flushChannel();
verify(sslDriver, times(2)).nonApplicationWrite();
verify(channel, times(2)).write(sslDriver.getNetworkWriteBuffer());
}
public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(false, false, true, true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, true, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
context.flushChannel();
verify(sslDriver, times(1)).nonApplicationWrite();
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(writeOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(10);
when(writeOperation.isFullyFlushed()).thenReturn(false,true);
context.flushChannel();
verify(writeOperation).incrementIndex(10);
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(selector).executeListener(listener, null);
assertFalse(context.hasQueuedWriteOps());
}
public void testPartialFlush() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(writeOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(writeOperation.isFullyFlushed()).thenReturn(false, false);
context.flushChannel();
verify(writeOperation).incrementIndex(5);
verify(channel, times(1)).write(sslDriver.getNetworkWriteBuffer());
verify(selector, times(0)).executeListener(listener, null);
assertTrue(context.hasQueuedWriteOps());
}
@SuppressWarnings("unchecked")
public void testMultipleWritesPartialFlushes() throws IOException {
BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)};
ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)};
BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class);
BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class);
when(writeOperation1.getBuffersToWrite()).thenReturn(buffers1);
when(writeOperation2.getBuffersToWrite()).thenReturn(buffers2);
when(writeOperation1.getListener()).thenReturn(listener);
when(writeOperation2.getListener()).thenReturn(listener2);
context.queueWriteOperation(writeOperation1);
context.queueWriteOperation(writeOperation2);
when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5);
when(sslDriver.applicationWrite(buffers2)).thenReturn(3);
when(writeOperation1.isFullyFlushed()).thenReturn(false, false, true);
when(writeOperation2.isFullyFlushed()).thenReturn(false);
context.flushChannel();
verify(writeOperation1, times(2)).incrementIndex(5);
verify(channel, times(3)).write(sslDriver.getNetworkWriteBuffer());
verify(selector).executeListener(listener, null);
verify(selector, times(0)).executeListener(listener2, null);
assertTrue(context.hasQueuedWriteOps());
}
public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
context.queueWriteOperation(writeOperation);
IOException exception = new IOException();
when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
when(writeOperation.getListener()).thenReturn(listener);
when(sslDriver.hasFlushPending()).thenReturn(false, false);
when(sslDriver.readyForApplicationWrites()).thenReturn(true);
when(sslDriver.applicationWrite(buffers)).thenReturn(5);
when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception);
when(writeOperation.isFullyFlushed()).thenReturn(false);
expectThrows(IOException.class, () -> context.flushChannel());
verify(writeOperation).incrementIndex(5);
verify(selector).executeFailedListener(listener, exception);
assertFalse(context.hasQueuedWriteOps());
}
public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception {
when(sslDriver.hasFlushPending()).thenReturn(true);
when(sslDriver.needsNonApplicationWrite()).thenReturn(true);
when(sslDriver.readyForApplicationWrites()).thenReturn(false);
when(channel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(new IOException());
assertFalse(context.selectorShouldClose());
expectThrows(IOException.class, () -> context.flushChannel());
assertTrue(context.selectorShouldClose());
}
public void testReadyToCloseIfDriverIndicateClosed() {
when(sslDriver.isClosed()).thenReturn(false, true);
assertFalse(context.selectorShouldClose());
assertTrue(context.selectorShouldClose());
}
public void testInitiateCloseFromDifferentThreadSchedulesCloseNotify() {
when(selector.isOnCurrentThread()).thenReturn(false, true);
context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
verify(selector).queueWrite(captor.capture());
context.queueWriteOperation(captor.getValue());
verify(sslDriver).initiateClose();
}
public void testInitiateCloseFromSameThreadSchedulesCloseNotify() {
context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
verify(selector).queueWriteInChannelBuffer(captor.capture());
context.queueWriteOperation(captor.getValue());
verify(sslDriver).initiateClose();
}
public void testRegisterInitiatesDriver() throws IOException {
context.channelRegistered();
verify(sslDriver).init();
}
private Answer getAnswerForBytes(byte[] bytes) {
return invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
buffer.ensureCapacity(buffer.getIndex() + bytes.length);
ByteBuffer[] buffers = buffer.sliceBuffersFrom(buffer.getIndex());
assert buffers[0].remaining() > bytes.length;
buffers[0].put(bytes);
buffer.incrementIndex(bytes.length);
return bytes.length;
};
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -0,0 +1,308 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.test.ESTestCase;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.function.Supplier;
public class SSLDriverTests extends ESTestCase {
private final Supplier<InboundChannelBuffer.Page> pageSupplier =
() -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {});
private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
public void testPingPongAndClose() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
assertFalse(clientDriver.needsNonApplicationWrite());
normalClose(clientDriver, serverDriver);
}
public void testBigAppData() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
handshake(clientDriver, serverDriver);
ByteBuffer buffer = ByteBuffer.allocate(1 << 15);
for (int i = 0; i < (1 << 15); ++i) {
buffer.put((byte) i);
}
ByteBuffer[] buffers = {buffer};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[0].limit());
assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[1].limit());
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
assertFalse(clientDriver.needsNonApplicationWrite());
normalClose(clientDriver, serverDriver);
}
public void testHandshakeFailureBecauseProtocolMismatch() throws Exception {
SSLContext sslContext = getSSLContext();
SSLEngine clientEngine = sslContext.createSSLEngine();
SSLEngine serverEngine = sslContext.createSSLEngine();
String[] serverProtocols = {"TLSv1.1", "TLSv1.2"};
serverEngine.setEnabledProtocols(serverProtocols);
String[] clientProtocols = {"TLSv1"};
clientEngine.setEnabledProtocols(clientProtocols);
SSLDriver clientDriver = getDriver(clientEngine, true);
SSLDriver serverDriver = getDriver(serverEngine, false);
SSLException sslException = expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver));
assertEquals("Client requested protocol TLSv1 not enabled or not supported", sslException.getMessage());
failedCloseAlert(serverDriver, clientDriver);
}
public void testHandshakeFailureBecauseNoCiphers() throws Exception {
SSLContext sslContext = getSSLContext();
SSLEngine clientEngine = sslContext.createSSLEngine();
SSLEngine serverEngine = sslContext.createSSLEngine();
String[] enabledCipherSuites = clientEngine.getEnabledCipherSuites();
int midpoint = enabledCipherSuites.length / 2;
String[] serverCiphers = Arrays.copyOfRange(enabledCipherSuites, 0, midpoint);
serverEngine.setEnabledCipherSuites(serverCiphers);
String[] clientCiphers = Arrays.copyOfRange(enabledCipherSuites, midpoint, enabledCipherSuites.length - 1);
clientEngine.setEnabledCipherSuites(clientCiphers);
SSLDriver clientDriver = getDriver(clientEngine, true);
SSLDriver serverDriver = getDriver(serverEngine, false);
SSLException sslException = expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver));
assertEquals("no cipher suites in common", sslException.getMessage());
failedCloseAlert(serverDriver, clientDriver);
}
public void testCloseDuringHandshake() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
clientDriver.init();
serverDriver.init();
assertTrue(clientDriver.needsNonApplicationWrite());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
sendHandshakeMessages(serverDriver, clientDriver);
sendData(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
assertFalse(serverDriver.needsNonApplicationWrite());
serverDriver.initiateClose();
assertTrue(serverDriver.needsNonApplicationWrite());
assertFalse(serverDriver.isClosed());
sendNeededWrites(serverDriver, clientDriver);
// We are immediately fully closed due to SSLEngine inconsistency
assertTrue(serverDriver.isClosed());
// This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
clientDriver.read(clientBuffer);
sendNeededWrites(clientDriver, serverDriver);
SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer));
assertEquals("Received close_notify during handshake", sslException.getMessage());
assertTrue(clientDriver.needsNonApplicationWrite());
sendNeededWrites(clientDriver, serverDriver);
serverDriver.read(serverBuffer);
assertTrue(clientDriver.isClosed());
}
private void failedCloseAlert(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
assertTrue(sendDriver.needsNonApplicationWrite());
assertFalse(sendDriver.isClosed());
sendNeededWrites(sendDriver, receiveDriver);
assertTrue(sendDriver.isClosed());
sendDriver.close();
SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(genericBuffer));
assertEquals("Received fatal alert: handshake_failure", sslException.getMessage());
if (receiveDriver.needsNonApplicationWrite() == false) {
assertTrue(receiveDriver.isClosed());
receiveDriver.close();
} else {
assertFalse(receiveDriver.isClosed());
expectThrows(SSLException.class, receiveDriver::close);
}
}
private SSLContext getSSLContext() throws Exception {
String relativePath = "/org/elasticsearch/xpack/security/transport/ssl/certs/simple/testclient.jks";
SSLContext sslContext;
try (InputStream in = Files.newInputStream(getDataPath(relativePath))) {
KeyStore keyStore = KeyStore.getInstance("jks");
keyStore.load(in, "testclient".toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(keyStore);
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(keyStore, "testclient".toCharArray());
sslContext = SSLContext.getInstance("TLSv1.2");
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
return sslContext;
}
}
private void normalClose(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException {
sendDriver.initiateClose();
assertFalse(sendDriver.readyForApplicationWrites());
assertTrue(sendDriver.needsNonApplicationWrite());
sendNeededWrites(sendDriver, receiveDriver);
assertFalse(sendDriver.isClosed());
receiveDriver.read(genericBuffer);
assertFalse(receiveDriver.isClosed());
assertFalse(receiveDriver.readyForApplicationWrites());
assertTrue(receiveDriver.needsNonApplicationWrite());
sendNeededWrites(receiveDriver, sendDriver);
assertTrue(receiveDriver.isClosed());
sendDriver.read(genericBuffer);
assertTrue(sendDriver.isClosed());
sendDriver.close();
receiveDriver.close();
}
private void sendNeededWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
if (sendDriver.hasFlushPending() == false) {
sendDriver.nonApplicationWrite();
}
if (sendDriver.hasFlushPending()) {
sendData(sendDriver, receiveDriver, true);
}
}
}
private void handshake(SSLDriver clientDriver, SSLDriver serverDriver) throws IOException {
clientDriver.init();
serverDriver.init();
assertTrue(clientDriver.needsNonApplicationWrite());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
sendHandshakeMessages(serverDriver, clientDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
sendHandshakeMessages(clientDriver, serverDriver);
assertTrue(clientDriver.isHandshaking());
assertTrue(serverDriver.isHandshaking());
sendHandshakeMessages(serverDriver, clientDriver);
assertFalse(clientDriver.isHandshaking());
assertFalse(serverDriver.isHandshaking());
}
private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException {
assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending());
while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) {
assertFalse(receiveDriver.needsNonApplicationWrite());
if (sendDriver.hasFlushPending() == false) {
sendDriver.nonApplicationWrite();
}
if (sendDriver.isHandshaking()) {
assertTrue(sendDriver.hasFlushPending());
sendData(sendDriver, receiveDriver);
assertFalse(sendDriver.hasFlushPending());
receiveDriver.read(genericBuffer);
}
}
if (receiveDriver.isHandshaking()) {
assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.hasFlushPending());
}
}
private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
assertFalse(sendDriver.needsNonApplicationWrite());
int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
int bytesEncrypted = 0;
while (bytesToEncrypt > bytesEncrypted) {
bytesEncrypted += sendDriver.applicationWrite(message);
sendData(sendDriver, receiveDriver);
}
}
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver) {
sendData(sendDriver, receiveDriver, randomBoolean());
}
private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver, boolean partial) {
ByteBuffer writeBuffer = sendDriver.getNetworkWriteBuffer();
ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
if (partial) {
int initialLimit = writeBuffer.limit();
int bytesToWrite = writeBuffer.remaining() / (randomInt(2) + 2);
writeBuffer.limit(writeBuffer.position() + bytesToWrite);
readBuffer.put(writeBuffer);
writeBuffer.limit(initialLimit);
assertTrue(sendDriver.hasFlushPending());
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
} else {
readBuffer.put(writeBuffer);
assertFalse(sendDriver.hasFlushPending());
}
}
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {
return new SSLDriver(engine, isClient);
}
}

View File

@ -0,0 +1,167 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.ssl.SSLService;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.file.Path;
import java.util.Collections;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTestCase {
private SSLService createSSLService() {
Path testnodeStore = getDataPath("/org/elasticsearch/xpack/security/transport/ssl/certs/simple/testnode.jks");
MockSecureSettings secureSettings = new MockSecureSettings();
secureSettings.setString("xpack.ssl.keystore.secure_password", "testnode");
Settings settings = Settings.builder()
.put("xpack.security.transport.ssl.enabled", true)
.put("xpack.ssl.keystore.path", testnodeStore)
.setSecureSettings(secureSettings)
.put("path.home", createTempDir())
.build();
try {
return new SSLService(settings, TestEnvironment.newEnvironment(settings));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
NetworkService networkService = new NetworkService(Collections.emptyList());
Settings settings1 = Settings.builder()
.put(settings)
.put("xpack.security.transport.ssl.enabled", true).build();
Transport transport = new SecurityNioTransport(settings1, threadPool,
networkService, BigArrays.NON_RECYCLING_INSTANCE, new MockPageCacheRecycler(settings), namedWriteableRegistry,
new NoneCircuitBreakerService(), createSSLService()) {
@Override
protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException,
InterruptedException {
if (doHandshake) {
return super.executeHandshake(node, channel, timeout);
} else {
return version.minimumCompatibilityVersion();
}
}
@Override
protected Version getCurrentVersion() {
return version;
}
};
MockTransportService mockTransportService =
MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings,
Collections.emptySet());
mockTransportService.start();
return mockTransportService;
}
@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
settings = Settings.builder().put(settings)
.put(TcpTransport.PORT.getKey(), "0")
.build();
MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake);
transportService.start();
return transportService;
}
@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
@SuppressWarnings("unchecked")
TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection;
TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true);
}
public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
emptyMap(), emptySet(),Version.CURRENT));
fail("Expected ConnectTransportException");
} catch (ConnectTransportException e) {
assertThat(e.getMessage(), containsString("connect_exception"));
assertThat(e.getMessage(), containsString("[127.0.0.1:9876]"));
Throwable cause = e.getCause();
assertThat(cause, instanceOf(IOException.class));
}
}
public void testBindUnavailableAddress() {
// this is on a lower level since it needs access to the TransportService before it's started
int port = serviceA.boundAddress().publishAddress().getPort();
Settings settings = Settings.builder()
.put(Node.NODE_NAME_SETTING.getKey(), "foobar")
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.put("transport.tcp.port", port)
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
try {
transportService.start();
} finally {
transportService.stop();
transportService.close();
}
});
assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage());
}
// TODO: These tests currently rely on plaintext transports
@Override
@AwaitsFix(bugUrl = "")
public void testTcpHandshake() throws IOException, InterruptedException {
}
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeWithIncompatVersion() {
}
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeUpdatesVersion() throws IOException {
}
}