diff --git a/nifi-commons/nifi-security-socket-ssl/pom.xml b/nifi-commons/nifi-security-socket-ssl/pom.xml
index c4edb99be7..b4592a5864 100644
--- a/nifi-commons/nifi-security-socket-ssl/pom.xml
+++ b/nifi-commons/nifi-security-socket-ssl/pom.xml
@@ -31,5 +31,17 @@
org.slf4j
slf4j-api
+
+ org.apache.nifi
+ nifi-security-utils
+ 1.14.0-SNAPSHOT
+ test
+
+
+ io.netty
+ netty-handler
+ 4.1.65.Final
+ test
+
diff --git a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
index 59902d05f5..9a5cdd8b50 100644
--- a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
+++ b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
@@ -26,8 +26,9 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.Status;
+import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
-import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
@@ -38,89 +39,87 @@ import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.SocketChannel;
-import java.security.cert.Certificate;
-import java.security.cert.CertificateException;
-import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;
+/**
+ * SSLSocketChannel supports reading and writing bytes using TLS and NIO SocketChannels with configurable timeouts
+ */
public class SSLSocketChannel implements Closeable {
+ private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
- public static final int MAX_WRITE_SIZE = 65536;
-
- private static final Logger logger = LoggerFactory.getLogger(SSLSocketChannel.class);
+ private static final int DISCARD_BUFFER_LENGTH = 8192;
+ private static final int END_OF_STREAM = -1;
+ private static final byte[] EMPTY_MESSAGE = new byte[0];
private static final long BUFFER_FULL_EMPTY_WAIT_NANOS = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS);
+ private static final long FINISH_CONNECT_SLEEP = 50;
+ private static final long INITIAL_INCREMENTAL_SLEEP = 1;
+ private static final boolean CLIENT_AUTHENTICATION_REQUIRED = true;
private final String remoteAddress;
private final int port;
private final SSLEngine engine;
private final SocketAddress socketAddress;
-
- private BufferStateManager streamInManager;
- private BufferStateManager streamOutManager;
- private BufferStateManager appDataManager;
-
- private SocketChannel channel;
-
- private final byte[] oneByteBuffer = new byte[1];
-
+ private final BufferStateManager streamInManager;
+ private final BufferStateManager streamOutManager;
+ private final BufferStateManager appDataManager;
+ private final SocketChannel channel;
private int timeoutMillis = 30000;
- private volatile boolean connected = false;
- private boolean handshaking = false;
- private boolean closed = false;
+
private volatile boolean interrupted = false;
+ private volatile ChannelStatus channelStatus = ChannelStatus.DISCONNECTED;
- public SSLSocketChannel(final SSLContext sslContext, final String hostname, final int port, final InetAddress localAddress, final boolean client) throws IOException {
- this.socketAddress = new InetSocketAddress(hostname, port);
- this.channel = SocketChannel.open();
- if (localAddress != null) {
- final SocketAddress localSocketAddress = new InetSocketAddress(localAddress, 0);
- this.channel.bind(localSocketAddress);
- }
- this.remoteAddress = hostname;
+ /**
+ * SSLSocketChannel constructor with SSLContext and remote address parameters
+ *
+ * @param sslContext SSLContext used to create SSLEngine with specified client mode
+ * @param remoteAddress Remote Address used for connection
+ * @param port Remote Port used for connection
+ * @param bindAddress Local address used for binding server channel when provided
+ * @param useClientMode Use Client Mode
+ * @throws IOException Thrown on failures creating Socket Channel
+ */
+ public SSLSocketChannel(final SSLContext sslContext, final String remoteAddress, final int port, final InetAddress bindAddress, final boolean useClientMode) throws IOException {
+ this.engine = createEngine(sslContext, useClientMode);
+ this.channel = createSocketChannel(bindAddress);
+ this.socketAddress = new InetSocketAddress(remoteAddress, port);
+ this.remoteAddress = remoteAddress;
this.port = port;
- this.engine = sslContext.createSSLEngine();
- this.engine.setUseClientMode(client);
- engine.setNeedClientAuth(true);
streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
}
- public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean client) throws IOException {
- if (!socketChannel.isConnected()) {
- throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
- }
-
- this.channel = socketChannel;
-
- this.socketAddress = socketChannel.getRemoteAddress();
- final Socket socket = socketChannel.socket();
- this.remoteAddress = socket.getInetAddress().toString();
- this.port = socket.getPort();
-
- this.engine = sslContext.createSSLEngine();
- this.engine.setUseClientMode(client);
- this.engine.setNeedClientAuth(true);
-
- streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
- streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
- appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
+ /**
+ * SSLSocketChannel constructor with SSLContext and connected SocketChannel
+ *
+ * @param sslContext SSLContext used to create SSLEngine with specified client mode
+ * @param socketChannel Connected SocketChannel
+ * @param useClientMode Use Client Mode
+ * @throws IOException Thrown on SocketChannel.getRemoteAddress()
+ */
+ public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean useClientMode) throws IOException {
+ this(createEngine(sslContext, useClientMode), socketChannel);
}
+ /**
+ * SSLSocketChannel constructor with configured SSLEngine and connected SocketChannel
+ *
+ * @param sslEngine SSLEngine configured with mode and client authentication
+ * @param socketChannel Connected SocketChannel
+ * @throws IOException Thrown on SocketChannel.getRemoteAddress()
+ */
public SSLSocketChannel(final SSLEngine sslEngine, final SocketChannel socketChannel) throws IOException {
if (!socketChannel.isConnected()) {
- throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
+ throw new IllegalArgumentException("Connected SocketChannel required");
}
+ socketChannel.configureBlocking(false);
this.channel = socketChannel;
-
this.socketAddress = socketChannel.getRemoteAddress();
final Socket socket = socketChannel.socket();
this.remoteAddress = socket.getInetAddress().toString();
this.port = socket.getPort();
-
- // don't set useClientMode or needClientAuth, use the engine as is and let the caller configure it
this.engine = sslEngine;
streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
@@ -128,166 +127,64 @@ public class SSLSocketChannel implements Closeable {
appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
}
- public void setTimeout(final int millis) {
- this.timeoutMillis = millis;
+ public void setTimeout(final int timeoutMillis) {
+ this.timeoutMillis = timeoutMillis;
}
public int getTimeout() {
return timeoutMillis;
}
+ /**
+ * Connect Channel when not connected and perform handshake process
+ *
+ * @throws IOException Thrown on connection failures
+ */
public void connect() throws IOException {
+ channelStatus = ChannelStatus.CONNECTING;
+
try {
- channel.configureBlocking(false);
if (!channel.isConnected()) {
- final long startTime = System.currentTimeMillis();
+ logOperation("Connection Started");
+ final long started = System.currentTimeMillis();
if (!channel.connect(socketAddress)) {
while (!channel.finishConnect()) {
- if (interrupted) {
- throw new TransmissionDisabledException();
- }
- if (System.currentTimeMillis() > startTime + timeoutMillis) {
- throw new SocketTimeoutException("Timed out connecting to " + remoteAddress + ":" + port);
- }
+ checkInterrupted();
+ checkTimeoutExceeded(started);
try {
- Thread.sleep(50L);
+ TimeUnit.MILLISECONDS.sleep(FINISH_CONNECT_SLEEP);
} catch (final InterruptedException e) {
+ logOperation("Connection Interrupted");
}
}
}
}
- engine.beginHandshake();
+ channelStatus = ChannelStatus.CONNECTED;
+ } catch (final Exception e) {
+ close();
+ throw new SSLException(String.format("[%s:%d] Connection Failed", remoteAddress, port), e);
+ }
+ try {
performHandshake();
- logger.debug("{} Successfully completed SSL handshake", this);
-
- streamInManager.clear();
- streamOutManager.clear();
- appDataManager.clear();
-
- connected = true;
- } catch (final Exception e) {
- logger.error("{} failed to connect", this, e);
- closeQuietly(channel);
- engine.closeInbound();
- engine.closeOutbound();
- throw e;
- }
- }
-
- public String getDn() throws CertificateException, SSLPeerUnverifiedException {
- final Certificate[] certs = engine.getSession().getPeerCertificates();
- if (certs == null || certs.length == 0) {
- throw new SSLPeerUnverifiedException("No certificates found");
- }
-
- final Certificate certificate = certs[0];
- if (certificate instanceof X509Certificate) {
- final X509Certificate peerCertificate = (X509Certificate) certificate;
- peerCertificate.checkValidity();
- return peerCertificate.getSubjectDN().getName().trim();
- } else {
- throw new CertificateException(String.format("X.509 Certificate class not found [%s]", certificate.getClass()));
- }
- }
-
- private void performHandshake() throws IOException {
- // Generate handshake message
- final byte[] emptyMessage = new byte[0];
- handshaking = true;
- logger.debug("{} Performing Handshake", this);
-
- try {
- while (true) {
- switch (engine.getHandshakeStatus()) {
- case FINISHED:
- return;
- case NEED_WRAP: {
- final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
-
- final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
- final SSLEngineResult wrapHelloResult = engine.wrap(appDataOut, outboundBuffer);
- if (wrapHelloResult.getStatus() == Status.BUFFER_OVERFLOW) {
- streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
- continue;
- }
-
- if (wrapHelloResult.getStatus() != Status.OK) {
- throw new SSLHandshakeException("Could not generate SSL Handshake information: SSLEngineResult: "
- + wrapHelloResult.toString());
- }
-
- logger.trace("{} Handshake response after wrapping: {}", this, wrapHelloResult);
-
- final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
- final int bytesToSend = readableStreamOut.remaining();
- writeFully(readableStreamOut);
- logger.trace("{} Sent {} bytes of wrapped data for handshake", this, bytesToSend);
-
- streamOutManager.clear();
- }
- continue;
- case NEED_UNWRAP: {
- final ByteBuffer readableDataIn = streamInManager.prepareForRead(0);
- final ByteBuffer appData = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
- // Read handshake response from other side
- logger.trace("{} Unwrapping: {} to {}", this, readableDataIn, appData);
- SSLEngineResult handshakeResponseResult = engine.unwrap(readableDataIn, appData);
- logger.trace("{} Handshake response after unwrapping: {}", this, handshakeResponseResult);
-
- if (handshakeResponseResult.getStatus() == Status.BUFFER_UNDERFLOW) {
- final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- final int bytesRead = readData(writableDataIn);
- if (bytesRead > 0) {
- logger.trace("{} Read {} bytes for handshake", this, bytesRead);
- }
-
- if (bytesRead < 0) {
- throw new SSLHandshakeException("Reached End-of-File marker while performing handshake");
- }
- } else if (handshakeResponseResult.getStatus() == Status.CLOSED) {
- throw new IOException("Channel was closed by peer during handshake");
- } else {
- streamInManager.compact();
- appDataManager.clear();
- }
- }
- break;
- case NEED_TASK:
- performTasks();
- continue;
- case NOT_HANDSHAKING:
- return;
- }
- }
- } finally {
- handshaking = false;
- }
- }
-
- private void performTasks() {
- Runnable runnable;
- while ((runnable = engine.getDelegatedTask()) != null) {
- runnable.run();
- }
- }
-
- private void closeQuietly(final Closeable closeable) {
- try {
- closeable.close();
- } catch (final Exception e) {
+ } catch (final IOException e) {
+ close();
+ throw new SSLException(String.format("[%s:%d] Handshake Failed", remoteAddress, port), e);
}
}
+ /**
+ * Shutdown Socket Channel input and read available bytes
+ *
+ * @throws IOException Thrown on Socket Channel failures
+ */
public void consume() throws IOException {
channel.shutdownInput();
- final byte[] b = new byte[4096];
- final ByteBuffer buffer = ByteBuffer.wrap(b);
+ final byte[] byteBuffer = new byte[DISCARD_BUFFER_LENGTH];
+ final ByteBuffer buffer = ByteBuffer.wrap(byteBuffer);
int readCount;
do {
readCount = channel.read(buffer);
@@ -295,209 +192,104 @@ public class SSLSocketChannel implements Closeable {
} while (readCount > 0);
}
- private int readData(final ByteBuffer dest) throws IOException {
- final long startTime = System.currentTimeMillis();
-
- while (true) {
- if (interrupted) {
- throw new TransmissionDisabledException();
- }
-
- if (dest.remaining() == 0) {
- return 0;
- }
-
- final int readCount = channel.read(dest);
-
- long sleepNanos = 1L;
- if (readCount == 0) {
- if (System.currentTimeMillis() > startTime + timeoutMillis) {
- throw new SocketTimeoutException("Timed out reading from socket connected to " + remoteAddress + ":" + port);
- }
- try {
- TimeUnit.NANOSECONDS.sleep(sleepNanos);
- } catch (InterruptedException e) {
- close();
- Thread.currentThread().interrupt(); // set the interrupt status
- throw new ClosedByInterruptException();
- }
-
- sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
-
- continue;
- }
-
- logger.trace("{} Read {} bytes", this, readCount);
- return readCount;
- }
- }
-
- private Status encryptAndWriteFully(final BufferStateManager src) throws IOException {
- SSLEngineResult result = null;
-
- final ByteBuffer buff = src.prepareForRead(0);
- final ByteBuffer outBuff = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
- logger.trace("{} Encrypting {} bytes", this, buff.remaining());
- while (buff.remaining() > 0) {
- result = engine.wrap(buff, outBuff);
- if (result.getStatus() == Status.OK) {
- final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0);
- writeFully(readableOutBuff);
- streamOutManager.clear();
- } else {
- return result.getStatus();
- }
- }
-
- return result.getStatus();
- }
-
- private void writeFully(final ByteBuffer src) throws IOException {
- long lastByteWrittenTime = System.currentTimeMillis();
-
- int bytesWritten = 0;
- while (src.hasRemaining()) {
- if (interrupted) {
- throw new TransmissionDisabledException();
- }
-
- final int written = channel.write(src);
- bytesWritten += written;
- final long now = System.currentTimeMillis();
- long sleepNanos = 1L;
-
- if (written > 0) {
- lastByteWrittenTime = now;
- } else {
- if (now > lastByteWrittenTime + timeoutMillis) {
- throw new SocketTimeoutException("Timed out writing to socket connected to " + remoteAddress + ":" + port);
- }
- try {
- TimeUnit.NANOSECONDS.sleep(sleepNanos);
- } catch (final InterruptedException e) {
- close();
- Thread.currentThread().interrupt(); // set the interrupt status
- throw new ClosedByInterruptException();
- }
-
- sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
- }
- }
-
- logger.trace("{} Wrote {} bytes", this, bytesWritten);
- }
-
+ /**
+ * Is Channel Closed
+ *
+ * @return Channel Closed Status
+ */
public boolean isClosed() {
- if (closed) {
+ if (ChannelStatus.CLOSED == channelStatus) {
return true;
}
- // need to detect if peer has sent closure handshake...if so the answer is true
- final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- int readCount = 0;
- try {
- readCount = channel.read(writableInBuffer);
- } catch (IOException e) {
- logger.error("{} failed to read data", this, e);
- readCount = -1; // treat the condition same as if End of Stream
- }
- if (readCount == 0) {
- return false;
- }
- if (readCount > 0) {
- logger.trace("{} Read {} bytes", this, readCount);
- final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
- final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+ // Read Channel to determine closed status
+ final ByteBuffer inputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+ int bytesRead;
+ try {
+ bytesRead = channel.read(inputBuffer);
+ } catch (final IOException e) {
+ LOGGER.warn("[{}:{}] Closed Status Read Failed", remoteAddress, port, e);
+ bytesRead = END_OF_STREAM;
+ }
+ logOperationBytes("Closed Status Read", bytesRead);
+
+ if (bytesRead == 0) {
+ return false;
+ } else if (bytesRead > 0) {
try {
- SSLEngineResult unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
- logger.trace("{} When checking if closed, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);
- if (unwrapResponse.getStatus().equals(Status.CLOSED)) {
- // Drain the incoming TCP buffer
- final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
- int bytesDiscarded = channel.read(discardBuffer);
- while (bytesDiscarded > 0) {
- discardBuffer.clear();
- bytesDiscarded = channel.read(discardBuffer);
- }
+ final SSLEngineResult unwrapResult = unwrap();
+ if (Status.CLOSED == unwrapResult.getStatus()) {
+ readChannelDiscard();
engine.closeInbound();
} else {
streamInManager.compact();
return false;
}
- } catch (IOException e) {
- logger.error("{} failed to check if closed. Closing channel.", this, e);
+ } catch (final IOException e) {
+ LOGGER.warn("[{}:{}] Closed Status Unwrap Failed", remoteAddress, port, e);
}
}
- // either readCount is -1, indicating an end of stream, or the peer sent a closure handshake
- // so go ahead and close down the channel
- closeQuietly(channel.socket());
- closeQuietly(channel);
- closed = true;
+
+ // Close Channel when encountering end of stream or closed status
+ try {
+ close();
+ } catch (final IOException e) {
+ LOGGER.warn("[{}:{}] Close Failed", remoteAddress, port, e);
+ }
return true;
}
+ /**
+ * Close Channel and process notifications
+ *
+ * @throws IOException Thrown on SSLEngine.wrap() failures
+ */
@Override
public void close() throws IOException {
- logger.debug("{} Closing Connection", this);
- if (channel == null) {
- return;
- }
-
- if (closed) {
+ logOperation("Close Requested");
+ if (channelStatus == ChannelStatus.CLOSED) {
return;
}
try {
engine.closeOutbound();
- final byte[] emptyMessage = new byte[0];
-
- final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
- final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
- final SSLEngineResult handshakeResult = engine.wrap(appDataOut, outboundBuffer);
-
- if (handshakeResult.getStatus() != Status.CLOSED) {
- throw new IOException("Invalid close state - will not send network data");
+ streamOutManager.clear();
+ final ByteBuffer inputBuffer = ByteBuffer.wrap(EMPTY_MESSAGE);
+ final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+ SSLEngineResult wrapResult = wrap(inputBuffer, outputBuffer);
+ Status status = wrapResult.getStatus();
+ if (Status.OK == status) {
+ logOperation("Clearing Outbound Buffer");
+ outputBuffer.clear();
+ wrapResult = wrap(inputBuffer, outputBuffer);
+ status = wrapResult.getStatus();
}
-
- final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
- writeFully(readableStreamOut);
- } finally {
- // Drain the incoming TCP buffer
- final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
- try {
- int bytesDiscarded = channel.read(discardBuffer);
- while (bytesDiscarded > 0) {
- discardBuffer.clear();
- bytesDiscarded = channel.read(discardBuffer);
+ if (Status.CLOSED == status) {
+ final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
+ try {
+ writeChannel(streamOutputBuffer);
+ } catch (final IOException e) {
+ logOperation(String.format("Write Close Notification Failed: %s", e.getMessage()));
}
- } catch (Exception e) {
+ } else {
+ throw new SSLException(String.format("[%s:%d] Invalid Wrap Result Status [%s]", remoteAddress, port, status));
}
-
+ } finally {
+ channelStatus = ChannelStatus.CLOSED;
+ readChannelDiscard();
closeQuietly(channel.socket());
closeQuietly(channel);
- closed = true;
+ logOperation("Close Completed");
}
}
- private int copyFromAppDataBuffer(final byte[] buffer, final int offset, final int len) {
- // If any data already exists in the application data buffer, copy it to the buffer.
- final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
-
- final int appDataRemaining = appDataBuffer.remaining();
- if (appDataRemaining > 0) {
- final int bytesToCopy = Math.min(len, appDataBuffer.remaining());
- appDataBuffer.get(buffer, offset, bytesToCopy);
-
- final int bytesCopied = appDataRemaining - appDataBuffer.remaining();
- logger.trace("{} Copied {} ({}) bytes from unencrypted application buffer to user space",
- this, bytesToCopy, bytesCopied);
- return bytesCopied;
- }
- return 0;
- }
-
+ /**
+ * Get application bytes available for reading
+ *
+ * @return Number of application bytes available for reading
+ * @throws IOException Thrown on failures checking for available bytes
+ */
public int available() throws IOException {
ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
@@ -506,8 +298,7 @@ public class SSLSocketChannel implements Closeable {
return buffered;
}
- final boolean wasAbleToRead = isDataAvailable();
- if (!wasAbleToRead) {
+ if (!isDataAvailable()) {
return 0;
}
@@ -516,6 +307,12 @@ public class SSLSocketChannel implements Closeable {
return appDataBuffer.remaining() + streamDataBuffer.remaining();
}
+ /**
+ * Is data available for reading
+ *
+ * @return Data available status
+ * @throws IOException Thrown on SocketChannel.read() failures
+ */
public boolean isDataAvailable() throws IOException {
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
@@ -529,101 +326,139 @@ public class SSLSocketChannel implements Closeable {
return (bytesRead > 0);
}
+ /**
+ * Read and return one byte
+ *
+ * @return Byte read or -1 when end of stream reached
+ * @throws IOException Thrown on read failures
+ */
public int read() throws IOException {
- final int bytesRead = read(oneByteBuffer);
- if (bytesRead == -1) {
- return -1;
+ final byte[] buffer = new byte[1];
+
+ final int bytesRead = read(buffer);
+ if (bytesRead == END_OF_STREAM) {
+ return END_OF_STREAM;
}
- return oneByteBuffer[0] & 0xFF;
+
+ return Byte.toUnsignedInt(buffer[0]);
}
+ /**
+ * Read available bytes into buffer
+ *
+ * @param buffer Byte array buffer
+ * @return Number of bytes read
+ * @throws IOException Thrown on read failures
+ */
public int read(final byte[] buffer) throws IOException {
return read(buffer, 0, buffer.length);
}
+ /**
+ * Read available bytes into buffer based on offset and length requested
+ *
+ * @param buffer Byte array buffer
+ * @param offset Buffer offset
+ * @param len Length of bytes to read
+ * @return Number of bytes read
+ * @throws IOException Thrown on read failures
+ */
public int read(final byte[] buffer, final int offset, final int len) throws IOException {
- logger.debug("{} Reading up to {} bytes of data", this, len);
+ logOperationBytes("Read Requested", len);
+ checkChannelStatus();
- if (!connected) {
- connect();
+ int applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+ if (applicationBytesRead > 0) {
+ return applicationBytesRead;
}
-
- int copied = copyFromAppDataBuffer(buffer, offset, len);
- if (copied > 0) {
- return copied;
- }
-
appDataManager.clear();
while (true) {
- // prepare buffers and call unwrap
- final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
- SSLEngineResult unwrapResponse = null;
- final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
- unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
- logger.trace("{} When reading data, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);
+ final SSLEngineResult unwrapResult = unwrap();
- switch (unwrapResponse.getStatus()) {
+ if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
+ // RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
+ logOperation("Processing Post-Handshake Messages");
+ continue;
+ }
+
+ final Status status = unwrapResult.getStatus();
+ switch (status) {
case BUFFER_OVERFLOW:
- throw new SSLHandshakeException("Buffer Overflow, which is not allowed to happen from an unwrap");
- case BUFFER_UNDERFLOW: {
-// appDataManager.prepareForRead(engine.getSession().getApplicationBufferSize());
-
- final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
- final int bytesRead = readData(writableInBuffer);
- if (bytesRead < 0) {
- return -1;
+ throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
+ case BUFFER_UNDERFLOW:
+ final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+ final int channelBytesRead = readChannel(streamBuffer);
+ logOperationBytes("Channel Read Completed", channelBytesRead);
+ if (channelBytesRead == END_OF_STREAM) {
+ return END_OF_STREAM;
}
-
- continue;
- }
+ break;
case CLOSED:
- copied = copyFromAppDataBuffer(buffer, offset, len);
- if (copied == 0) {
- return -1;
+ applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+ if (applicationBytesRead == 0) {
+ return END_OF_STREAM;
}
streamInManager.compact();
- return copied;
- case OK: {
- copied = copyFromAppDataBuffer(buffer, offset, len);
- if (copied == 0) {
- throw new IOException("Failed to decrypt data");
+ return applicationBytesRead;
+ case OK:
+ applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+ if (applicationBytesRead == 0) {
+ throw new IOException("Read Application Buffer Failed");
}
streamInManager.compact();
- return copied;
- }
+ return applicationBytesRead;
}
}
}
+ /**
+ * Write one byte to channel
+ *
+ * @param data Byte to be written
+ * @throws IOException Thrown on write failures
+ */
public void write(final int data) throws IOException {
write(new byte[]{(byte) data}, 0, 1);
}
+ /**
+ * Write bytes to channel
+ *
+ * @param data Byte array to be written
+ * @throws IOException Thrown on write failures
+ */
public void write(final byte[] data) throws IOException {
write(data, 0, data.length);
}
+ /**
+ * Write data to channel performs multiple iterations based on data length
+ *
+ * @param data Byte array to be written
+ * @param offset Byte array offset
+ * @param len Length of bytes for writing
+ * @throws IOException Thrown on write failures
+ */
public void write(final byte[] data, final int offset, final int len) throws IOException {
- logger.debug("{} Writing {} bytes of data", this, len);
+ logOperationBytes("Write Started", len);
+ checkChannelStatus();
- if (!connected) {
- connect();
- }
-
- int iterations = len / MAX_WRITE_SIZE;
- if (len % MAX_WRITE_SIZE > 0) {
+ final int applicationBufferSize = engine.getSession().getApplicationBufferSize();
+ logOperationBytes("Write Application Buffer Size", applicationBufferSize);
+ int iterations = len / applicationBufferSize;
+ if (len % applicationBufferSize > 0) {
iterations++;
}
for (int i = 0; i < iterations; i++) {
streamOutManager.clear();
- final int itrOffset = offset + i * MAX_WRITE_SIZE;
- final int itrLen = Math.min(len - itrOffset, MAX_WRITE_SIZE);
+ final int itrOffset = offset + i * applicationBufferSize;
+ final int itrLen = Math.min(len - itrOffset, applicationBufferSize);
final ByteBuffer byteBuffer = ByteBuffer.wrap(data, itrOffset, itrLen);
- final BufferStateManager buffMan = new BufferStateManager(byteBuffer, Direction.READ);
- final Status status = encryptAndWriteFully(buffMan);
+ final BufferStateManager bufferStateManager = new BufferStateManager(byteBuffer, Direction.READ);
+ final Status status = wrapWriteChannel(bufferStateManager);
switch (status) {
case BUFFER_OVERFLOW:
streamOutManager.ensureSize(engine.getSession().getPacketBufferSize());
@@ -639,7 +474,294 @@ public class SSLSocketChannel implements Closeable {
}
}
+ /**
+ * Interrupt processing and disable transmission
+ */
public void interrupt() {
this.interrupted = true;
}
+
+ private void performHandshake() throws IOException {
+ logOperation("Handshake Started");
+ channelStatus = ChannelStatus.HANDSHAKING;
+ engine.beginHandshake();
+
+ SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
+ while (true) {
+ logHandshakeStatus(handshakeStatus);
+
+ switch (handshakeStatus) {
+ case FINISHED:
+ case NOT_HANDSHAKING:
+ channelStatus = ChannelStatus.ESTABLISHED;
+ final SSLSession session = engine.getSession();
+ LOGGER.debug("[{}:{}] [{}] Negotiated Protocol [{}] Cipher Suite [{}]",
+ remoteAddress,
+ port,
+ channelStatus,
+ session.getProtocol(),
+ session.getCipherSuite()
+ );
+ return;
+ case NEED_TASK:
+ runDelegatedTasks();
+ handshakeStatus = engine.getHandshakeStatus();
+ break;
+ case NEED_UNWRAP:
+ final SSLEngineResult unwrapResult = unwrap();
+ handshakeStatus = unwrapResult.getHandshakeStatus();
+ Status unwrapResultStatus = unwrapResult.getStatus();
+
+ if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
+ final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+ final int bytesRead = readChannel(writableDataIn);
+ logOperationBytes("Handshake Channel Read", bytesRead);
+
+ if (bytesRead == END_OF_STREAM) {
+ throw getHandshakeException(handshakeStatus, "End of Stream Found");
+ }
+ } else if (unwrapResultStatus == Status.CLOSED) {
+ throw getHandshakeException(handshakeStatus, "Channel Closed");
+ } else {
+ streamInManager.compact();
+ appDataManager.clear();
+ }
+ break;
+ case NEED_WRAP:
+ final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+ final SSLEngineResult wrapResult = wrap(ByteBuffer.wrap(EMPTY_MESSAGE), outboundBuffer);
+ handshakeStatus = wrapResult.getHandshakeStatus();
+ final Status wrapResultStatus = wrapResult.getStatus();
+
+ if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
+ streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+ } else if (wrapResultStatus == Status.OK) {
+ final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
+ final int bytesRemaining = streamBuffer.remaining();
+ writeChannel(streamBuffer);
+ logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
+ streamOutManager.clear();
+ } else {
+ throw getHandshakeException(handshakeStatus, String.format("Wrap Failed [%s]", wrapResult.getStatus()));
+ }
+ break;
+ }
+ }
+ }
+
+ private int readChannel(final ByteBuffer outputBuffer) throws IOException {
+ logOperation("Channel Read Started");
+
+ final long started = System.currentTimeMillis();
+ long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
+ while (true) {
+ checkInterrupted();
+
+ if (outputBuffer.remaining() == 0) {
+ return 0;
+ }
+
+ final int channelBytesRead = channel.read(outputBuffer);
+ if (channelBytesRead == 0) {
+ checkTimeoutExceeded(started);
+ sleepNanoseconds = incrementalSleep(sleepNanoseconds);
+ continue;
+ }
+
+ return channelBytesRead;
+ }
+ }
+
+ private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
+ long lastWriteCompleted = System.currentTimeMillis();
+
+ int totalBytes = 0;
+ long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
+ while (inputBuffer.hasRemaining()) {
+ checkInterrupted();
+
+ final int written = channel.write(inputBuffer);
+ totalBytes += written;
+
+ if (written > 0) {
+ lastWriteCompleted = System.currentTimeMillis();
+ } else {
+ checkTimeoutExceeded(lastWriteCompleted);
+ sleepNanoseconds = incrementalSleep(sleepNanoseconds);
+ }
+ }
+
+ logOperationBytes("Channel Write Completed", totalBytes);
+ }
+
+ private long incrementalSleep(final long nanoseconds) throws IOException {
+ try {
+ TimeUnit.NANOSECONDS.sleep(nanoseconds);
+ } catch (final InterruptedException e) {
+ close();
+ Thread.currentThread().interrupt();
+ throw new ClosedByInterruptException();
+ }
+ return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
+ }
+
+ private void readChannelDiscard() {
+ try {
+ final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
+ int bytesRead = channel.read(readBuffer);
+ while (bytesRead > 0) {
+ readBuffer.clear();
+ bytesRead = channel.read(readBuffer);
+ }
+ } catch (final IOException e) {
+ LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
+ }
+ }
+
+ private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
+ logOperationBytes("Application Buffer Read Requested", len);
+ final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
+
+ final int appDataRemaining = appDataBuffer.remaining();
+ logOperationBytes("Application Buffer Remaining", appDataRemaining);
+ if (appDataRemaining > 0) {
+ final int bytesToCopy = Math.min(len, appDataBuffer.remaining());
+ appDataBuffer.get(buffer, offset, bytesToCopy);
+
+ final int bytesCopied = appDataRemaining - appDataBuffer.remaining();
+ logOperationBytes("Application Buffer Copied", bytesCopied);
+ return bytesCopied;
+ }
+ return 0;
+ }
+
+ private Status wrapWriteChannel(final BufferStateManager inputManager) throws IOException {
+ final ByteBuffer inputBuffer = inputManager.prepareForRead(0);
+ final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+
+ logOperationBytes("Wrap Started", inputBuffer.remaining());
+ Status status = Status.OK;
+ while (inputBuffer.remaining() > 0) {
+ final SSLEngineResult result = wrap(inputBuffer, outputBuffer);
+ status = result.getStatus();
+ if (status == Status.OK) {
+ final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0);
+ writeChannel(readableOutBuff);
+ streamOutManager.clear();
+ } else {
+ break;
+ }
+ }
+
+ return status;
+ }
+
+ private SSLEngineResult wrap(final ByteBuffer inputBuffer, final ByteBuffer outputBuffer) throws SSLException {
+ final SSLEngineResult result = engine.wrap(inputBuffer, outputBuffer);
+ logEngineResult(result, "WRAP Completed");
+ return result;
+ }
+
+ private SSLEngineResult unwrap() throws IOException {
+ final ByteBuffer streamBuffer = streamInManager.prepareForRead(engine.getSession().getPacketBufferSize());
+ final ByteBuffer applicationBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+ final SSLEngineResult result = engine.unwrap(streamBuffer, applicationBuffer);
+ logEngineResult(result, "UNWRAP Completed");
+ return result;
+ }
+
+ private void runDelegatedTasks() {
+ Runnable delegatedTask;
+ while ((delegatedTask = engine.getDelegatedTask()) != null) {
+ logOperation("Running Delegated Task");
+ delegatedTask.run();
+ }
+ }
+
+ private void closeQuietly(final Closeable closeable) {
+ try {
+ closeable.close();
+ } catch (final Exception e) {
+ logOperation(String.format("Close failed: %s", e.getMessage()));
+ }
+ }
+
+ private SSLHandshakeException getHandshakeException(final SSLEngineResult.HandshakeStatus handshakeStatus, final String message) {
+ final String formatted = String.format("[%s:%d] Handshake Status [%s] %s", remoteAddress, port, handshakeStatus, message);
+ return new SSLHandshakeException(formatted);
+ }
+
+ private void checkChannelStatus() throws IOException {
+ if (ChannelStatus.ESTABLISHED != channelStatus) {
+ connect();
+ }
+ }
+
+ private void checkInterrupted() {
+ if (interrupted) {
+ throw new TransmissionDisabledException();
+ }
+ }
+
+ private void checkTimeoutExceeded(final long started) throws SocketTimeoutException {
+ if (System.currentTimeMillis() > started + timeoutMillis) {
+ throw new SocketTimeoutException(String.format("Timeout Exceeded [%d ms] for [%s:%d]", timeoutMillis, remoteAddress, port));
+ }
+ }
+
+ private void logOperation(final String operation) {
+ LOGGER.trace("[{}:{}] [{}] {}", remoteAddress, port, channelStatus, operation);
+ }
+
+ private void logOperationBytes(final String operation, final int bytes) {
+ LOGGER.trace("[{}:{}] [{}] {} Bytes [{}]", remoteAddress, port, channelStatus, operation, bytes);
+ }
+
+ private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus) {
+ LOGGER.trace("[{}:{}] [{}] Handshake Status [{}]", remoteAddress, port, channelStatus, handshakeStatus);
+ }
+
+ private void logEngineResult(final SSLEngineResult result, final String method) {
+ LOGGER.trace("[{}:{}] [{}] {} Status [{}] Handshake Status [{}] Produced [{}] Consumed [{}]",
+ remoteAddress,
+ port,
+ channelStatus,
+ method,
+ result.getStatus(),
+ result.getHandshakeStatus(),
+ result.bytesProduced(),
+ result.bytesConsumed()
+ );
+ }
+
+ private static SocketChannel createSocketChannel(final InetAddress bindAddress) throws IOException {
+ final SocketChannel socketChannel = SocketChannel.open();
+ if (bindAddress != null) {
+ final SocketAddress socketAddress = new InetSocketAddress(bindAddress, 0);
+ socketChannel.bind(socketAddress);
+ }
+ socketChannel.configureBlocking(false);
+ return socketChannel;
+ }
+
+ private static SSLEngine createEngine(final SSLContext sslContext, final boolean useClientMode) {
+ final SSLEngine sslEngine = sslContext.createSSLEngine();
+ sslEngine.setUseClientMode(useClientMode);
+ sslEngine.setNeedClientAuth(CLIENT_AUTHENTICATION_REQUIRED);
+ return sslEngine;
+ }
+
+ private enum ChannelStatus {
+ DISCONNECTED,
+
+ CONNECTING,
+
+ CONNECTED,
+
+ HANDSHAKING,
+
+ ESTABLISHED,
+
+ CLOSED
+ }
}
diff --git a/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
new file mode 100644
index 0000000000..aa9dde595f
--- /dev/null
+++ b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.remote.io.socket.ssl;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+import io.netty.handler.codec.string.StringEncoder;
+import io.netty.handler.ssl.SslHandler;
+import org.apache.nifi.remote.io.socket.NetworkUtils;
+import org.apache.nifi.security.util.KeyStoreUtils;
+import org.apache.nifi.security.util.SslContextFactory;
+import org.apache.nifi.security.util.TlsConfiguration;
+import org.apache.nifi.security.util.TlsPlatform;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import java.io.File;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.security.GeneralSecurityException;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class SSLSocketChannelTest {
+ private static final String LOCALHOST = "localhost";
+
+ private static final int GROUP_THREADS = 1;
+
+ private static final boolean CLIENT_CHANNEL = true;
+
+ private static final boolean SERVER_CHANNEL = false;
+
+ private static final int CHANNEL_TIMEOUT = 15000;
+
+ private static final int CHANNEL_FAILURE_TIMEOUT = 100;
+
+ private static final int CHANNEL_POLL_TIMEOUT = 5000;
+
+ private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
+
+ private static final int MAX_MESSAGE_LENGTH = 1024;
+
+ private static final String TLS_1_3 = "TLSv1.3";
+
+ private static final String TLS_1_2 = "TLSv1.2";
+
+ private static final String MESSAGE = "PING\n";
+
+ private static final Charset MESSAGE_CHARSET = StandardCharsets.UTF_8;
+
+ private static final byte[] MESSAGE_BYTES = MESSAGE.getBytes(StandardCharsets.UTF_8);
+
+ private static final int FIRST_BYTE_OFFSET = 1;
+
+ private static SSLContext sslContext;
+
+ @BeforeClass
+ public static void setConfiguration() throws GeneralSecurityException, IOException {
+ final TlsConfiguration tlsConfiguration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore();
+ new File(tlsConfiguration.getKeystorePath()).deleteOnExit();
+ new File(tlsConfiguration.getTruststorePath()).deleteOnExit();
+ sslContext = SslContextFactory.createSslContext(tlsConfiguration);
+ }
+
+ @Test
+ public void testClientConnectFailed() throws IOException {
+ final int port = NetworkUtils.getAvailableTcpPort();
+ final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
+ sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
+ assertThrows(Exception.class, sslSocketChannel::connect);
+ }
+
+ @Test
+ public void testClientConnectHandshakeFailed() throws IOException {
+ assumeProtocolSupported(TLS_1_2);
+ final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+
+ try (final SocketChannel socketChannel = SocketChannel.open()) {
+ final int port = NetworkUtils.getAvailableTcpPort();
+ startServer(group, port, TLS_1_2);
+
+ socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
+ final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
+
+ final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
+ sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
+
+ group.shutdownGracefully().syncUninterruptibly();
+ assertThrows(SSLException.class, sslSocketChannel::connect);
+ } finally {
+ group.shutdownGracefully().syncUninterruptibly();
+ }
+ }
+
+ @Test
+ public void testClientConnectWriteReadTls12() throws Exception {
+ assumeProtocolSupported(TLS_1_2);
+ assertChannelConnectedWriteReadClosed(TLS_1_2);
+ }
+
+ @Test
+ public void testClientConnectWriteReadTls13() throws Exception {
+ assumeProtocolSupported(TLS_1_3);
+ assertChannelConnectedWriteReadClosed(TLS_1_3);
+ }
+
+ @Test(timeout = CHANNEL_TIMEOUT)
+ public void testServerReadWriteTls12() throws Exception {
+ assumeProtocolSupported(TLS_1_2);
+ assertServerChannelConnectedReadClosed(TLS_1_2);
+ }
+
+ @Test(timeout = CHANNEL_TIMEOUT)
+ public void testServerReadWriteTls13() throws Exception {
+ assumeProtocolSupported(TLS_1_3);
+ assertServerChannelConnectedReadClosed(TLS_1_3);
+ }
+
+ private void assumeProtocolSupported(final String protocol) {
+ Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
+ }
+
+ private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
+ final int port = NetworkUtils.getAvailableTcpPort();
+ final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
+ final SocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port);
+ serverSocketChannel.bind(socketAddress);
+
+ final Executor executor = Executors.newSingleThreadExecutor();
+ final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+ try {
+ final Channel channel = startClient(group, port, enabledProtocol);
+
+ try {
+ final SocketChannel socketChannel = serverSocketChannel.accept();
+ final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, socketChannel, SERVER_CHANNEL);
+
+ final BlockingQueue queue = new LinkedBlockingQueue<>();
+ final Runnable readCommand = () -> {
+ final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
+ try {
+ final int messageBytesRead = sslSocketChannel.read(messageBytes);
+ if (messageBytesRead == MESSAGE_BYTES.length) {
+ queue.add(new String(messageBytes, MESSAGE_CHARSET));
+ }
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ };
+ executor.execute(readCommand);
+ channel.writeAndFlush(MESSAGE).syncUninterruptibly();
+
+ final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
+ assertEquals("Message not matched", MESSAGE, messageRead);
+ } finally {
+ channel.close();
+ }
+ } finally {
+ group.shutdownGracefully().syncUninterruptibly();
+ serverSocketChannel.close();
+ }
+ }
+
+ private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
+ processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
+ try {
+ sslSocketChannel.connect();
+ assertFalse("Channel closed", sslSocketChannel.isClosed());
+
+ assertChannelWriteRead(sslSocketChannel);
+
+ sslSocketChannel.close();
+ assertTrue("Channel not closed", sslSocketChannel.isClosed());
+ } catch (final IOException e) {
+ throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
+ }
+ }));
+ }
+
+ private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
+ sslSocketChannel.write(MESSAGE_BYTES);
+
+ while (sslSocketChannel.available() == 0) {
+ try {
+ TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
+ } catch (final InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ final byte firstByteRead = (byte) sslSocketChannel.read();
+ assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
+
+ final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
+ messageBytes[0] = firstByteRead;
+
+ final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
+ assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
+
+ final String message = new String(messageBytes, MESSAGE_CHARSET);
+ assertEquals("Channel Message not matched", MESSAGE, message);
+ }
+
+ private void processClientSslSocketChannel(final String enabledProtocol, final Consumer channelConsumer) throws IOException {
+ final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+
+ try {
+ final int port = NetworkUtils.getAvailableTcpPort();
+ startServer(group, port, enabledProtocol);
+ final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
+ sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
+ channelConsumer.accept(sslSocketChannel);
+ } finally {
+ group.shutdownGracefully().syncUninterruptibly();
+ }
+ }
+
+ private Channel startClient(final EventLoopGroup group, final int port, final String enabledProtocol) {
+ final Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(group);
+ bootstrap.channel(NioSocketChannel.class);
+ bootstrap.handler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(final Channel channel) {
+ final ChannelPipeline pipeline = channel.pipeline();
+ final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
+ setPipelineHandlers(pipeline, sslEngine);
+ }
+ });
+ return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
+ }
+
+ private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
+ final ServerBootstrap bootstrap = new ServerBootstrap();
+ bootstrap.group(group);
+ bootstrap.channel(NioServerSocketChannel.class);
+ bootstrap.childHandler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(final Channel channel) {
+ final ChannelPipeline pipeline = channel.pipeline();
+ final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL);
+ setPipelineHandlers(pipeline, sslEngine);
+ pipeline.addLast(new SimpleChannelInboundHandler() {
+ @Override
+ protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
+ channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
+ }
+ });
+ }
+ });
+
+ final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, port);
+ bindFuture.syncUninterruptibly();
+ }
+
+ private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) {
+ final SSLEngine sslEngine = sslContext.createSSLEngine();
+ sslEngine.setUseClientMode(useClientMode);
+ sslEngine.setEnabledProtocols(new String[]{enabledProtocol});
+ return sslEngine;
+ }
+
+ private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) {
+ pipeline.addLast(new SslHandler(sslEngine));
+ pipeline.addLast(new DelimiterBasedFrameDecoder(MAX_MESSAGE_LENGTH, Delimiters.lineDelimiter()));
+ pipeline.addLast(new StringDecoder());
+ pipeline.addLast(new StringEncoder());
+ }
+}
diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
index 70771f1cbd..e2f05cc34b 100644
--- a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
+++ b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
@@ -68,9 +68,10 @@ public class SSLSocketChannelSender extends SocketChannelSender {
@Override
public void close() {
- super.close();
- IOUtils.closeQuietly(sslOutputStream);
+ // Close SSLSocketChannel before closing other resources
IOUtils.closeQuietly(sslChannel);
+ IOUtils.closeQuietly(sslOutputStream);
+ super.close();
sslChannel = null;
}