diff --git a/nifi-commons/nifi-security-socket-ssl/pom.xml b/nifi-commons/nifi-security-socket-ssl/pom.xml index c4edb99be7..b4592a5864 100644 --- a/nifi-commons/nifi-security-socket-ssl/pom.xml +++ b/nifi-commons/nifi-security-socket-ssl/pom.xml @@ -31,5 +31,17 @@ org.slf4j slf4j-api + + org.apache.nifi + nifi-security-utils + 1.14.0-SNAPSHOT + test + + + io.netty + netty-handler + 4.1.65.Final + test + diff --git a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java index 59902d05f5..9a5cdd8b50 100644 --- a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java +++ b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java @@ -26,8 +26,9 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; @@ -38,89 +39,87 @@ import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.nio.channels.ClosedByInterruptException; import java.nio.channels.SocketChannel; -import java.security.cert.Certificate; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; import java.util.concurrent.TimeUnit; +/** + * SSLSocketChannel supports reading and writing bytes using TLS and NIO SocketChannels with configurable timeouts + */ public class SSLSocketChannel implements Closeable { + private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class); - public static final int MAX_WRITE_SIZE = 65536; - - private static final Logger logger = LoggerFactory.getLogger(SSLSocketChannel.class); + private static final int DISCARD_BUFFER_LENGTH = 8192; + private static final int END_OF_STREAM = -1; + private static final byte[] EMPTY_MESSAGE = new byte[0]; private static final long BUFFER_FULL_EMPTY_WAIT_NANOS = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS); + private static final long FINISH_CONNECT_SLEEP = 50; + private static final long INITIAL_INCREMENTAL_SLEEP = 1; + private static final boolean CLIENT_AUTHENTICATION_REQUIRED = true; private final String remoteAddress; private final int port; private final SSLEngine engine; private final SocketAddress socketAddress; - - private BufferStateManager streamInManager; - private BufferStateManager streamOutManager; - private BufferStateManager appDataManager; - - private SocketChannel channel; - - private final byte[] oneByteBuffer = new byte[1]; - + private final BufferStateManager streamInManager; + private final BufferStateManager streamOutManager; + private final BufferStateManager appDataManager; + private final SocketChannel channel; private int timeoutMillis = 30000; - private volatile boolean connected = false; - private boolean handshaking = false; - private boolean closed = false; + private volatile boolean interrupted = false; + private volatile ChannelStatus channelStatus = ChannelStatus.DISCONNECTED; - public SSLSocketChannel(final SSLContext sslContext, final String hostname, final int port, final InetAddress localAddress, final boolean client) throws IOException { - this.socketAddress = new InetSocketAddress(hostname, port); - this.channel = SocketChannel.open(); - if (localAddress != null) { - final SocketAddress localSocketAddress = new InetSocketAddress(localAddress, 0); - this.channel.bind(localSocketAddress); - } - this.remoteAddress = hostname; + /** + * SSLSocketChannel constructor with SSLContext and remote address parameters + * + * @param sslContext SSLContext used to create SSLEngine with specified client mode + * @param remoteAddress Remote Address used for connection + * @param port Remote Port used for connection + * @param bindAddress Local address used for binding server channel when provided + * @param useClientMode Use Client Mode + * @throws IOException Thrown on failures creating Socket Channel + */ + public SSLSocketChannel(final SSLContext sslContext, final String remoteAddress, final int port, final InetAddress bindAddress, final boolean useClientMode) throws IOException { + this.engine = createEngine(sslContext, useClientMode); + this.channel = createSocketChannel(bindAddress); + this.socketAddress = new InetSocketAddress(remoteAddress, port); + this.remoteAddress = remoteAddress; this.port = port; - this.engine = sslContext.createSSLEngine(); - this.engine.setUseClientMode(client); - engine.setNeedClientAuth(true); streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize())); streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize())); appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize())); } - public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean client) throws IOException { - if (!socketChannel.isConnected()) { - throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel"); - } - - this.channel = socketChannel; - - this.socketAddress = socketChannel.getRemoteAddress(); - final Socket socket = socketChannel.socket(); - this.remoteAddress = socket.getInetAddress().toString(); - this.port = socket.getPort(); - - this.engine = sslContext.createSSLEngine(); - this.engine.setUseClientMode(client); - this.engine.setNeedClientAuth(true); - - streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize())); - streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize())); - appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize())); + /** + * SSLSocketChannel constructor with SSLContext and connected SocketChannel + * + * @param sslContext SSLContext used to create SSLEngine with specified client mode + * @param socketChannel Connected SocketChannel + * @param useClientMode Use Client Mode + * @throws IOException Thrown on SocketChannel.getRemoteAddress() + */ + public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean useClientMode) throws IOException { + this(createEngine(sslContext, useClientMode), socketChannel); } + /** + * SSLSocketChannel constructor with configured SSLEngine and connected SocketChannel + * + * @param sslEngine SSLEngine configured with mode and client authentication + * @param socketChannel Connected SocketChannel + * @throws IOException Thrown on SocketChannel.getRemoteAddress() + */ public SSLSocketChannel(final SSLEngine sslEngine, final SocketChannel socketChannel) throws IOException { if (!socketChannel.isConnected()) { - throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel"); + throw new IllegalArgumentException("Connected SocketChannel required"); } + socketChannel.configureBlocking(false); this.channel = socketChannel; - this.socketAddress = socketChannel.getRemoteAddress(); final Socket socket = socketChannel.socket(); this.remoteAddress = socket.getInetAddress().toString(); this.port = socket.getPort(); - - // don't set useClientMode or needClientAuth, use the engine as is and let the caller configure it this.engine = sslEngine; streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize())); @@ -128,166 +127,64 @@ public class SSLSocketChannel implements Closeable { appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize())); } - public void setTimeout(final int millis) { - this.timeoutMillis = millis; + public void setTimeout(final int timeoutMillis) { + this.timeoutMillis = timeoutMillis; } public int getTimeout() { return timeoutMillis; } + /** + * Connect Channel when not connected and perform handshake process + * + * @throws IOException Thrown on connection failures + */ public void connect() throws IOException { + channelStatus = ChannelStatus.CONNECTING; + try { - channel.configureBlocking(false); if (!channel.isConnected()) { - final long startTime = System.currentTimeMillis(); + logOperation("Connection Started"); + final long started = System.currentTimeMillis(); if (!channel.connect(socketAddress)) { while (!channel.finishConnect()) { - if (interrupted) { - throw new TransmissionDisabledException(); - } - if (System.currentTimeMillis() > startTime + timeoutMillis) { - throw new SocketTimeoutException("Timed out connecting to " + remoteAddress + ":" + port); - } + checkInterrupted(); + checkTimeoutExceeded(started); try { - Thread.sleep(50L); + TimeUnit.MILLISECONDS.sleep(FINISH_CONNECT_SLEEP); } catch (final InterruptedException e) { + logOperation("Connection Interrupted"); } } } } - engine.beginHandshake(); + channelStatus = ChannelStatus.CONNECTED; + } catch (final Exception e) { + close(); + throw new SSLException(String.format("[%s:%d] Connection Failed", remoteAddress, port), e); + } + try { performHandshake(); - logger.debug("{} Successfully completed SSL handshake", this); - - streamInManager.clear(); - streamOutManager.clear(); - appDataManager.clear(); - - connected = true; - } catch (final Exception e) { - logger.error("{} failed to connect", this, e); - closeQuietly(channel); - engine.closeInbound(); - engine.closeOutbound(); - throw e; - } - } - - public String getDn() throws CertificateException, SSLPeerUnverifiedException { - final Certificate[] certs = engine.getSession().getPeerCertificates(); - if (certs == null || certs.length == 0) { - throw new SSLPeerUnverifiedException("No certificates found"); - } - - final Certificate certificate = certs[0]; - if (certificate instanceof X509Certificate) { - final X509Certificate peerCertificate = (X509Certificate) certificate; - peerCertificate.checkValidity(); - return peerCertificate.getSubjectDN().getName().trim(); - } else { - throw new CertificateException(String.format("X.509 Certificate class not found [%s]", certificate.getClass())); - } - } - - private void performHandshake() throws IOException { - // Generate handshake message - final byte[] emptyMessage = new byte[0]; - handshaking = true; - logger.debug("{} Performing Handshake", this); - - try { - while (true) { - switch (engine.getHandshakeStatus()) { - case FINISHED: - return; - case NEED_WRAP: { - final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage); - - final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - - final SSLEngineResult wrapHelloResult = engine.wrap(appDataOut, outboundBuffer); - if (wrapHelloResult.getStatus() == Status.BUFFER_OVERFLOW) { - streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - continue; - } - - if (wrapHelloResult.getStatus() != Status.OK) { - throw new SSLHandshakeException("Could not generate SSL Handshake information: SSLEngineResult: " - + wrapHelloResult.toString()); - } - - logger.trace("{} Handshake response after wrapping: {}", this, wrapHelloResult); - - final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1); - final int bytesToSend = readableStreamOut.remaining(); - writeFully(readableStreamOut); - logger.trace("{} Sent {} bytes of wrapped data for handshake", this, bytesToSend); - - streamOutManager.clear(); - } - continue; - case NEED_UNWRAP: { - final ByteBuffer readableDataIn = streamInManager.prepareForRead(0); - final ByteBuffer appData = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - - // Read handshake response from other side - logger.trace("{} Unwrapping: {} to {}", this, readableDataIn, appData); - SSLEngineResult handshakeResponseResult = engine.unwrap(readableDataIn, appData); - logger.trace("{} Handshake response after unwrapping: {}", this, handshakeResponseResult); - - if (handshakeResponseResult.getStatus() == Status.BUFFER_UNDERFLOW) { - final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); - final int bytesRead = readData(writableDataIn); - if (bytesRead > 0) { - logger.trace("{} Read {} bytes for handshake", this, bytesRead); - } - - if (bytesRead < 0) { - throw new SSLHandshakeException("Reached End-of-File marker while performing handshake"); - } - } else if (handshakeResponseResult.getStatus() == Status.CLOSED) { - throw new IOException("Channel was closed by peer during handshake"); - } else { - streamInManager.compact(); - appDataManager.clear(); - } - } - break; - case NEED_TASK: - performTasks(); - continue; - case NOT_HANDSHAKING: - return; - } - } - } finally { - handshaking = false; - } - } - - private void performTasks() { - Runnable runnable; - while ((runnable = engine.getDelegatedTask()) != null) { - runnable.run(); - } - } - - private void closeQuietly(final Closeable closeable) { - try { - closeable.close(); - } catch (final Exception e) { + } catch (final IOException e) { + close(); + throw new SSLException(String.format("[%s:%d] Handshake Failed", remoteAddress, port), e); } } + /** + * Shutdown Socket Channel input and read available bytes + * + * @throws IOException Thrown on Socket Channel failures + */ public void consume() throws IOException { channel.shutdownInput(); - final byte[] b = new byte[4096]; - final ByteBuffer buffer = ByteBuffer.wrap(b); + final byte[] byteBuffer = new byte[DISCARD_BUFFER_LENGTH]; + final ByteBuffer buffer = ByteBuffer.wrap(byteBuffer); int readCount; do { readCount = channel.read(buffer); @@ -295,209 +192,104 @@ public class SSLSocketChannel implements Closeable { } while (readCount > 0); } - private int readData(final ByteBuffer dest) throws IOException { - final long startTime = System.currentTimeMillis(); - - while (true) { - if (interrupted) { - throw new TransmissionDisabledException(); - } - - if (dest.remaining() == 0) { - return 0; - } - - final int readCount = channel.read(dest); - - long sleepNanos = 1L; - if (readCount == 0) { - if (System.currentTimeMillis() > startTime + timeoutMillis) { - throw new SocketTimeoutException("Timed out reading from socket connected to " + remoteAddress + ":" + port); - } - try { - TimeUnit.NANOSECONDS.sleep(sleepNanos); - } catch (InterruptedException e) { - close(); - Thread.currentThread().interrupt(); // set the interrupt status - throw new ClosedByInterruptException(); - } - - sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS); - - continue; - } - - logger.trace("{} Read {} bytes", this, readCount); - return readCount; - } - } - - private Status encryptAndWriteFully(final BufferStateManager src) throws IOException { - SSLEngineResult result = null; - - final ByteBuffer buff = src.prepareForRead(0); - final ByteBuffer outBuff = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - - logger.trace("{} Encrypting {} bytes", this, buff.remaining()); - while (buff.remaining() > 0) { - result = engine.wrap(buff, outBuff); - if (result.getStatus() == Status.OK) { - final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0); - writeFully(readableOutBuff); - streamOutManager.clear(); - } else { - return result.getStatus(); - } - } - - return result.getStatus(); - } - - private void writeFully(final ByteBuffer src) throws IOException { - long lastByteWrittenTime = System.currentTimeMillis(); - - int bytesWritten = 0; - while (src.hasRemaining()) { - if (interrupted) { - throw new TransmissionDisabledException(); - } - - final int written = channel.write(src); - bytesWritten += written; - final long now = System.currentTimeMillis(); - long sleepNanos = 1L; - - if (written > 0) { - lastByteWrittenTime = now; - } else { - if (now > lastByteWrittenTime + timeoutMillis) { - throw new SocketTimeoutException("Timed out writing to socket connected to " + remoteAddress + ":" + port); - } - try { - TimeUnit.NANOSECONDS.sleep(sleepNanos); - } catch (final InterruptedException e) { - close(); - Thread.currentThread().interrupt(); // set the interrupt status - throw new ClosedByInterruptException(); - } - - sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS); - } - } - - logger.trace("{} Wrote {} bytes", this, bytesWritten); - } - + /** + * Is Channel Closed + * + * @return Channel Closed Status + */ public boolean isClosed() { - if (closed) { + if (ChannelStatus.CLOSED == channelStatus) { return true; } - // need to detect if peer has sent closure handshake...if so the answer is true - final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); - int readCount = 0; - try { - readCount = channel.read(writableInBuffer); - } catch (IOException e) { - logger.error("{} failed to read data", this, e); - readCount = -1; // treat the condition same as if End of Stream - } - if (readCount == 0) { - return false; - } - if (readCount > 0) { - logger.trace("{} Read {} bytes", this, readCount); - final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1); - final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + // Read Channel to determine closed status + final ByteBuffer inputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); + int bytesRead; + try { + bytesRead = channel.read(inputBuffer); + } catch (final IOException e) { + LOGGER.warn("[{}:{}] Closed Status Read Failed", remoteAddress, port, e); + bytesRead = END_OF_STREAM; + } + logOperationBytes("Closed Status Read", bytesRead); + + if (bytesRead == 0) { + return false; + } else if (bytesRead > 0) { try { - SSLEngineResult unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer); - logger.trace("{} When checking if closed, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse); - if (unwrapResponse.getStatus().equals(Status.CLOSED)) { - // Drain the incoming TCP buffer - final ByteBuffer discardBuffer = ByteBuffer.allocate(8192); - int bytesDiscarded = channel.read(discardBuffer); - while (bytesDiscarded > 0) { - discardBuffer.clear(); - bytesDiscarded = channel.read(discardBuffer); - } + final SSLEngineResult unwrapResult = unwrap(); + if (Status.CLOSED == unwrapResult.getStatus()) { + readChannelDiscard(); engine.closeInbound(); } else { streamInManager.compact(); return false; } - } catch (IOException e) { - logger.error("{} failed to check if closed. Closing channel.", this, e); + } catch (final IOException e) { + LOGGER.warn("[{}:{}] Closed Status Unwrap Failed", remoteAddress, port, e); } } - // either readCount is -1, indicating an end of stream, or the peer sent a closure handshake - // so go ahead and close down the channel - closeQuietly(channel.socket()); - closeQuietly(channel); - closed = true; + + // Close Channel when encountering end of stream or closed status + try { + close(); + } catch (final IOException e) { + LOGGER.warn("[{}:{}] Close Failed", remoteAddress, port, e); + } return true; } + /** + * Close Channel and process notifications + * + * @throws IOException Thrown on SSLEngine.wrap() failures + */ @Override public void close() throws IOException { - logger.debug("{} Closing Connection", this); - if (channel == null) { - return; - } - - if (closed) { + logOperation("Close Requested"); + if (channelStatus == ChannelStatus.CLOSED) { return; } try { engine.closeOutbound(); - final byte[] emptyMessage = new byte[0]; - - final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage); - final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - final SSLEngineResult handshakeResult = engine.wrap(appDataOut, outboundBuffer); - - if (handshakeResult.getStatus() != Status.CLOSED) { - throw new IOException("Invalid close state - will not send network data"); + streamOutManager.clear(); + final ByteBuffer inputBuffer = ByteBuffer.wrap(EMPTY_MESSAGE); + final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + SSLEngineResult wrapResult = wrap(inputBuffer, outputBuffer); + Status status = wrapResult.getStatus(); + if (Status.OK == status) { + logOperation("Clearing Outbound Buffer"); + outputBuffer.clear(); + wrapResult = wrap(inputBuffer, outputBuffer); + status = wrapResult.getStatus(); } - - final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1); - writeFully(readableStreamOut); - } finally { - // Drain the incoming TCP buffer - final ByteBuffer discardBuffer = ByteBuffer.allocate(8192); - try { - int bytesDiscarded = channel.read(discardBuffer); - while (bytesDiscarded > 0) { - discardBuffer.clear(); - bytesDiscarded = channel.read(discardBuffer); + if (Status.CLOSED == status) { + final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1); + try { + writeChannel(streamOutputBuffer); + } catch (final IOException e) { + logOperation(String.format("Write Close Notification Failed: %s", e.getMessage())); } - } catch (Exception e) { + } else { + throw new SSLException(String.format("[%s:%d] Invalid Wrap Result Status [%s]", remoteAddress, port, status)); } - + } finally { + channelStatus = ChannelStatus.CLOSED; + readChannelDiscard(); closeQuietly(channel.socket()); closeQuietly(channel); - closed = true; + logOperation("Close Completed"); } } - private int copyFromAppDataBuffer(final byte[] buffer, final int offset, final int len) { - // If any data already exists in the application data buffer, copy it to the buffer. - final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1); - - final int appDataRemaining = appDataBuffer.remaining(); - if (appDataRemaining > 0) { - final int bytesToCopy = Math.min(len, appDataBuffer.remaining()); - appDataBuffer.get(buffer, offset, bytesToCopy); - - final int bytesCopied = appDataRemaining - appDataBuffer.remaining(); - logger.trace("{} Copied {} ({}) bytes from unencrypted application buffer to user space", - this, bytesToCopy, bytesCopied); - return bytesCopied; - } - return 0; - } - + /** + * Get application bytes available for reading + * + * @return Number of application bytes available for reading + * @throws IOException Thrown on failures checking for available bytes + */ public int available() throws IOException { ByteBuffer appDataBuffer = appDataManager.prepareForRead(1); ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1); @@ -506,8 +298,7 @@ public class SSLSocketChannel implements Closeable { return buffered; } - final boolean wasAbleToRead = isDataAvailable(); - if (!wasAbleToRead) { + if (!isDataAvailable()) { return 0; } @@ -516,6 +307,12 @@ public class SSLSocketChannel implements Closeable { return appDataBuffer.remaining() + streamDataBuffer.remaining(); } + /** + * Is data available for reading + * + * @return Data available status + * @throws IOException Thrown on SocketChannel.read() failures + */ public boolean isDataAvailable() throws IOException { final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1); final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1); @@ -529,101 +326,139 @@ public class SSLSocketChannel implements Closeable { return (bytesRead > 0); } + /** + * Read and return one byte + * + * @return Byte read or -1 when end of stream reached + * @throws IOException Thrown on read failures + */ public int read() throws IOException { - final int bytesRead = read(oneByteBuffer); - if (bytesRead == -1) { - return -1; + final byte[] buffer = new byte[1]; + + final int bytesRead = read(buffer); + if (bytesRead == END_OF_STREAM) { + return END_OF_STREAM; } - return oneByteBuffer[0] & 0xFF; + + return Byte.toUnsignedInt(buffer[0]); } + /** + * Read available bytes into buffer + * + * @param buffer Byte array buffer + * @return Number of bytes read + * @throws IOException Thrown on read failures + */ public int read(final byte[] buffer) throws IOException { return read(buffer, 0, buffer.length); } + /** + * Read available bytes into buffer based on offset and length requested + * + * @param buffer Byte array buffer + * @param offset Buffer offset + * @param len Length of bytes to read + * @return Number of bytes read + * @throws IOException Thrown on read failures + */ public int read(final byte[] buffer, final int offset, final int len) throws IOException { - logger.debug("{} Reading up to {} bytes of data", this, len); + logOperationBytes("Read Requested", len); + checkChannelStatus(); - if (!connected) { - connect(); + int applicationBytesRead = readApplicationBuffer(buffer, offset, len); + if (applicationBytesRead > 0) { + return applicationBytesRead; } - - int copied = copyFromAppDataBuffer(buffer, offset, len); - if (copied > 0) { - return copied; - } - appDataManager.clear(); while (true) { - // prepare buffers and call unwrap - final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1); - SSLEngineResult unwrapResponse = null; - final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); - unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer); - logger.trace("{} When reading data, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse); + final SSLEngineResult unwrapResult = unwrap(); - switch (unwrapResponse.getStatus()) { + if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) { + // RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3 + logOperation("Processing Post-Handshake Messages"); + continue; + } + + final Status status = unwrapResult.getStatus(); + switch (status) { case BUFFER_OVERFLOW: - throw new SSLHandshakeException("Buffer Overflow, which is not allowed to happen from an unwrap"); - case BUFFER_UNDERFLOW: { -// appDataManager.prepareForRead(engine.getSession().getApplicationBufferSize()); - - final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); - final int bytesRead = readData(writableInBuffer); - if (bytesRead < 0) { - return -1; + throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status)); + case BUFFER_UNDERFLOW: + final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); + final int channelBytesRead = readChannel(streamBuffer); + logOperationBytes("Channel Read Completed", channelBytesRead); + if (channelBytesRead == END_OF_STREAM) { + return END_OF_STREAM; } - - continue; - } + break; case CLOSED: - copied = copyFromAppDataBuffer(buffer, offset, len); - if (copied == 0) { - return -1; + applicationBytesRead = readApplicationBuffer(buffer, offset, len); + if (applicationBytesRead == 0) { + return END_OF_STREAM; } streamInManager.compact(); - return copied; - case OK: { - copied = copyFromAppDataBuffer(buffer, offset, len); - if (copied == 0) { - throw new IOException("Failed to decrypt data"); + return applicationBytesRead; + case OK: + applicationBytesRead = readApplicationBuffer(buffer, offset, len); + if (applicationBytesRead == 0) { + throw new IOException("Read Application Buffer Failed"); } streamInManager.compact(); - return copied; - } + return applicationBytesRead; } } } + /** + * Write one byte to channel + * + * @param data Byte to be written + * @throws IOException Thrown on write failures + */ public void write(final int data) throws IOException { write(new byte[]{(byte) data}, 0, 1); } + /** + * Write bytes to channel + * + * @param data Byte array to be written + * @throws IOException Thrown on write failures + */ public void write(final byte[] data) throws IOException { write(data, 0, data.length); } + /** + * Write data to channel performs multiple iterations based on data length + * + * @param data Byte array to be written + * @param offset Byte array offset + * @param len Length of bytes for writing + * @throws IOException Thrown on write failures + */ public void write(final byte[] data, final int offset, final int len) throws IOException { - logger.debug("{} Writing {} bytes of data", this, len); + logOperationBytes("Write Started", len); + checkChannelStatus(); - if (!connected) { - connect(); - } - - int iterations = len / MAX_WRITE_SIZE; - if (len % MAX_WRITE_SIZE > 0) { + final int applicationBufferSize = engine.getSession().getApplicationBufferSize(); + logOperationBytes("Write Application Buffer Size", applicationBufferSize); + int iterations = len / applicationBufferSize; + if (len % applicationBufferSize > 0) { iterations++; } for (int i = 0; i < iterations; i++) { streamOutManager.clear(); - final int itrOffset = offset + i * MAX_WRITE_SIZE; - final int itrLen = Math.min(len - itrOffset, MAX_WRITE_SIZE); + final int itrOffset = offset + i * applicationBufferSize; + final int itrLen = Math.min(len - itrOffset, applicationBufferSize); final ByteBuffer byteBuffer = ByteBuffer.wrap(data, itrOffset, itrLen); - final BufferStateManager buffMan = new BufferStateManager(byteBuffer, Direction.READ); - final Status status = encryptAndWriteFully(buffMan); + final BufferStateManager bufferStateManager = new BufferStateManager(byteBuffer, Direction.READ); + final Status status = wrapWriteChannel(bufferStateManager); switch (status) { case BUFFER_OVERFLOW: streamOutManager.ensureSize(engine.getSession().getPacketBufferSize()); @@ -639,7 +474,294 @@ public class SSLSocketChannel implements Closeable { } } + /** + * Interrupt processing and disable transmission + */ public void interrupt() { this.interrupted = true; } + + private void performHandshake() throws IOException { + logOperation("Handshake Started"); + channelStatus = ChannelStatus.HANDSHAKING; + engine.beginHandshake(); + + SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); + while (true) { + logHandshakeStatus(handshakeStatus); + + switch (handshakeStatus) { + case FINISHED: + case NOT_HANDSHAKING: + channelStatus = ChannelStatus.ESTABLISHED; + final SSLSession session = engine.getSession(); + LOGGER.debug("[{}:{}] [{}] Negotiated Protocol [{}] Cipher Suite [{}]", + remoteAddress, + port, + channelStatus, + session.getProtocol(), + session.getCipherSuite() + ); + return; + case NEED_TASK: + runDelegatedTasks(); + handshakeStatus = engine.getHandshakeStatus(); + break; + case NEED_UNWRAP: + final SSLEngineResult unwrapResult = unwrap(); + handshakeStatus = unwrapResult.getHandshakeStatus(); + Status unwrapResultStatus = unwrapResult.getStatus(); + + if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) { + final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize()); + final int bytesRead = readChannel(writableDataIn); + logOperationBytes("Handshake Channel Read", bytesRead); + + if (bytesRead == END_OF_STREAM) { + throw getHandshakeException(handshakeStatus, "End of Stream Found"); + } + } else if (unwrapResultStatus == Status.CLOSED) { + throw getHandshakeException(handshakeStatus, "Channel Closed"); + } else { + streamInManager.compact(); + appDataManager.clear(); + } + break; + case NEED_WRAP: + final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + final SSLEngineResult wrapResult = wrap(ByteBuffer.wrap(EMPTY_MESSAGE), outboundBuffer); + handshakeStatus = wrapResult.getHandshakeStatus(); + final Status wrapResultStatus = wrapResult.getStatus(); + + if (wrapResultStatus == Status.BUFFER_OVERFLOW) { + streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + } else if (wrapResultStatus == Status.OK) { + final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1); + final int bytesRemaining = streamBuffer.remaining(); + writeChannel(streamBuffer); + logOperationBytes("Handshake Channel Write Completed", bytesRemaining); + streamOutManager.clear(); + } else { + throw getHandshakeException(handshakeStatus, String.format("Wrap Failed [%s]", wrapResult.getStatus())); + } + break; + } + } + } + + private int readChannel(final ByteBuffer outputBuffer) throws IOException { + logOperation("Channel Read Started"); + + final long started = System.currentTimeMillis(); + long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP; + while (true) { + checkInterrupted(); + + if (outputBuffer.remaining() == 0) { + return 0; + } + + final int channelBytesRead = channel.read(outputBuffer); + if (channelBytesRead == 0) { + checkTimeoutExceeded(started); + sleepNanoseconds = incrementalSleep(sleepNanoseconds); + continue; + } + + return channelBytesRead; + } + } + + private void writeChannel(final ByteBuffer inputBuffer) throws IOException { + long lastWriteCompleted = System.currentTimeMillis(); + + int totalBytes = 0; + long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP; + while (inputBuffer.hasRemaining()) { + checkInterrupted(); + + final int written = channel.write(inputBuffer); + totalBytes += written; + + if (written > 0) { + lastWriteCompleted = System.currentTimeMillis(); + } else { + checkTimeoutExceeded(lastWriteCompleted); + sleepNanoseconds = incrementalSleep(sleepNanoseconds); + } + } + + logOperationBytes("Channel Write Completed", totalBytes); + } + + private long incrementalSleep(final long nanoseconds) throws IOException { + try { + TimeUnit.NANOSECONDS.sleep(nanoseconds); + } catch (final InterruptedException e) { + close(); + Thread.currentThread().interrupt(); + throw new ClosedByInterruptException(); + } + return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS); + } + + private void readChannelDiscard() { + try { + final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH); + int bytesRead = channel.read(readBuffer); + while (bytesRead > 0) { + readBuffer.clear(); + bytesRead = channel.read(readBuffer); + } + } catch (final IOException e) { + LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e); + } + } + + private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) { + logOperationBytes("Application Buffer Read Requested", len); + final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len); + + final int appDataRemaining = appDataBuffer.remaining(); + logOperationBytes("Application Buffer Remaining", appDataRemaining); + if (appDataRemaining > 0) { + final int bytesToCopy = Math.min(len, appDataBuffer.remaining()); + appDataBuffer.get(buffer, offset, bytesToCopy); + + final int bytesCopied = appDataRemaining - appDataBuffer.remaining(); + logOperationBytes("Application Buffer Copied", bytesCopied); + return bytesCopied; + } + return 0; + } + + private Status wrapWriteChannel(final BufferStateManager inputManager) throws IOException { + final ByteBuffer inputBuffer = inputManager.prepareForRead(0); + final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + + logOperationBytes("Wrap Started", inputBuffer.remaining()); + Status status = Status.OK; + while (inputBuffer.remaining() > 0) { + final SSLEngineResult result = wrap(inputBuffer, outputBuffer); + status = result.getStatus(); + if (status == Status.OK) { + final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0); + writeChannel(readableOutBuff); + streamOutManager.clear(); + } else { + break; + } + } + + return status; + } + + private SSLEngineResult wrap(final ByteBuffer inputBuffer, final ByteBuffer outputBuffer) throws SSLException { + final SSLEngineResult result = engine.wrap(inputBuffer, outputBuffer); + logEngineResult(result, "WRAP Completed"); + return result; + } + + private SSLEngineResult unwrap() throws IOException { + final ByteBuffer streamBuffer = streamInManager.prepareForRead(engine.getSession().getPacketBufferSize()); + final ByteBuffer applicationBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); + final SSLEngineResult result = engine.unwrap(streamBuffer, applicationBuffer); + logEngineResult(result, "UNWRAP Completed"); + return result; + } + + private void runDelegatedTasks() { + Runnable delegatedTask; + while ((delegatedTask = engine.getDelegatedTask()) != null) { + logOperation("Running Delegated Task"); + delegatedTask.run(); + } + } + + private void closeQuietly(final Closeable closeable) { + try { + closeable.close(); + } catch (final Exception e) { + logOperation(String.format("Close failed: %s", e.getMessage())); + } + } + + private SSLHandshakeException getHandshakeException(final SSLEngineResult.HandshakeStatus handshakeStatus, final String message) { + final String formatted = String.format("[%s:%d] Handshake Status [%s] %s", remoteAddress, port, handshakeStatus, message); + return new SSLHandshakeException(formatted); + } + + private void checkChannelStatus() throws IOException { + if (ChannelStatus.ESTABLISHED != channelStatus) { + connect(); + } + } + + private void checkInterrupted() { + if (interrupted) { + throw new TransmissionDisabledException(); + } + } + + private void checkTimeoutExceeded(final long started) throws SocketTimeoutException { + if (System.currentTimeMillis() > started + timeoutMillis) { + throw new SocketTimeoutException(String.format("Timeout Exceeded [%d ms] for [%s:%d]", timeoutMillis, remoteAddress, port)); + } + } + + private void logOperation(final String operation) { + LOGGER.trace("[{}:{}] [{}] {}", remoteAddress, port, channelStatus, operation); + } + + private void logOperationBytes(final String operation, final int bytes) { + LOGGER.trace("[{}:{}] [{}] {} Bytes [{}]", remoteAddress, port, channelStatus, operation, bytes); + } + + private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus) { + LOGGER.trace("[{}:{}] [{}] Handshake Status [{}]", remoteAddress, port, channelStatus, handshakeStatus); + } + + private void logEngineResult(final SSLEngineResult result, final String method) { + LOGGER.trace("[{}:{}] [{}] {} Status [{}] Handshake Status [{}] Produced [{}] Consumed [{}]", + remoteAddress, + port, + channelStatus, + method, + result.getStatus(), + result.getHandshakeStatus(), + result.bytesProduced(), + result.bytesConsumed() + ); + } + + private static SocketChannel createSocketChannel(final InetAddress bindAddress) throws IOException { + final SocketChannel socketChannel = SocketChannel.open(); + if (bindAddress != null) { + final SocketAddress socketAddress = new InetSocketAddress(bindAddress, 0); + socketChannel.bind(socketAddress); + } + socketChannel.configureBlocking(false); + return socketChannel; + } + + private static SSLEngine createEngine(final SSLContext sslContext, final boolean useClientMode) { + final SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setUseClientMode(useClientMode); + sslEngine.setNeedClientAuth(CLIENT_AUTHENTICATION_REQUIRED); + return sslEngine; + } + + private enum ChannelStatus { + DISCONNECTED, + + CONNECTING, + + CONNECTED, + + HANDSHAKING, + + ESTABLISHED, + + CLOSED + } } diff --git a/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java new file mode 100644 index 0000000000..aa9dde595f --- /dev/null +++ b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.remote.io.socket.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.handler.ssl.SslHandler; +import org.apache.nifi.remote.io.socket.NetworkUtils; +import org.apache.nifi.security.util.KeyStoreUtils; +import org.apache.nifi.security.util.SslContextFactory; +import org.apache.nifi.security.util.TlsConfiguration; +import org.apache.nifi.security.util.TlsPlatform; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SSLSocketChannelTest { + private static final String LOCALHOST = "localhost"; + + private static final int GROUP_THREADS = 1; + + private static final boolean CLIENT_CHANNEL = true; + + private static final boolean SERVER_CHANNEL = false; + + private static final int CHANNEL_TIMEOUT = 15000; + + private static final int CHANNEL_FAILURE_TIMEOUT = 100; + + private static final int CHANNEL_POLL_TIMEOUT = 5000; + + private static final long CHANNEL_SLEEP_BEFORE_READ = 100; + + private static final int MAX_MESSAGE_LENGTH = 1024; + + private static final String TLS_1_3 = "TLSv1.3"; + + private static final String TLS_1_2 = "TLSv1.2"; + + private static final String MESSAGE = "PING\n"; + + private static final Charset MESSAGE_CHARSET = StandardCharsets.UTF_8; + + private static final byte[] MESSAGE_BYTES = MESSAGE.getBytes(StandardCharsets.UTF_8); + + private static final int FIRST_BYTE_OFFSET = 1; + + private static SSLContext sslContext; + + @BeforeClass + public static void setConfiguration() throws GeneralSecurityException, IOException { + final TlsConfiguration tlsConfiguration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore(); + new File(tlsConfiguration.getKeystorePath()).deleteOnExit(); + new File(tlsConfiguration.getTruststorePath()).deleteOnExit(); + sslContext = SslContextFactory.createSslContext(tlsConfiguration); + } + + @Test + public void testClientConnectFailed() throws IOException { + final int port = NetworkUtils.getAvailableTcpPort(); + final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL); + sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT); + assertThrows(Exception.class, sslSocketChannel::connect); + } + + @Test + public void testClientConnectHandshakeFailed() throws IOException { + assumeProtocolSupported(TLS_1_2); + final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS); + + try (final SocketChannel socketChannel = SocketChannel.open()) { + final int port = NetworkUtils.getAvailableTcpPort(); + startServer(group, port, TLS_1_2); + + socketChannel.connect(new InetSocketAddress(LOCALHOST, port)); + final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL); + + final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel); + sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT); + + group.shutdownGracefully().syncUninterruptibly(); + assertThrows(SSLException.class, sslSocketChannel::connect); + } finally { + group.shutdownGracefully().syncUninterruptibly(); + } + } + + @Test + public void testClientConnectWriteReadTls12() throws Exception { + assumeProtocolSupported(TLS_1_2); + assertChannelConnectedWriteReadClosed(TLS_1_2); + } + + @Test + public void testClientConnectWriteReadTls13() throws Exception { + assumeProtocolSupported(TLS_1_3); + assertChannelConnectedWriteReadClosed(TLS_1_3); + } + + @Test(timeout = CHANNEL_TIMEOUT) + public void testServerReadWriteTls12() throws Exception { + assumeProtocolSupported(TLS_1_2); + assertServerChannelConnectedReadClosed(TLS_1_2); + } + + @Test(timeout = CHANNEL_TIMEOUT) + public void testServerReadWriteTls13() throws Exception { + assumeProtocolSupported(TLS_1_3); + assertServerChannelConnectedReadClosed(TLS_1_3); + } + + private void assumeProtocolSupported(final String protocol) { + Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol)); + } + + private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException { + final int port = NetworkUtils.getAvailableTcpPort(); + final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); + final SocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port); + serverSocketChannel.bind(socketAddress); + + final Executor executor = Executors.newSingleThreadExecutor(); + final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS); + try { + final Channel channel = startClient(group, port, enabledProtocol); + + try { + final SocketChannel socketChannel = serverSocketChannel.accept(); + final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, socketChannel, SERVER_CHANNEL); + + final BlockingQueue queue = new LinkedBlockingQueue<>(); + final Runnable readCommand = () -> { + final byte[] messageBytes = new byte[MESSAGE_BYTES.length]; + try { + final int messageBytesRead = sslSocketChannel.read(messageBytes); + if (messageBytesRead == MESSAGE_BYTES.length) { + queue.add(new String(messageBytes, MESSAGE_CHARSET)); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + executor.execute(readCommand); + channel.writeAndFlush(MESSAGE).syncUninterruptibly(); + + final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS); + assertEquals("Message not matched", MESSAGE, messageRead); + } finally { + channel.close(); + } + } finally { + group.shutdownGracefully().syncUninterruptibly(); + serverSocketChannel.close(); + } + } + + private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException { + processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> { + try { + sslSocketChannel.connect(); + assertFalse("Channel closed", sslSocketChannel.isClosed()); + + assertChannelWriteRead(sslSocketChannel); + + sslSocketChannel.close(); + assertTrue("Channel not closed", sslSocketChannel.isClosed()); + } catch (final IOException e) { + throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e); + } + })); + } + + private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException { + sslSocketChannel.write(MESSAGE_BYTES); + + while (sslSocketChannel.available() == 0) { + try { + TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + } + + final byte firstByteRead = (byte) sslSocketChannel.read(); + assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead); + + final byte[] messageBytes = new byte[MESSAGE_BYTES.length]; + messageBytes[0] = firstByteRead; + + final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length); + assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead); + + final String message = new String(messageBytes, MESSAGE_CHARSET); + assertEquals("Channel Message not matched", MESSAGE, message); + } + + private void processClientSslSocketChannel(final String enabledProtocol, final Consumer channelConsumer) throws IOException { + final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS); + + try { + final int port = NetworkUtils.getAvailableTcpPort(); + startServer(group, port, enabledProtocol); + final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL); + sslSocketChannel.setTimeout(CHANNEL_TIMEOUT); + channelConsumer.accept(sslSocketChannel); + } finally { + group.shutdownGracefully().syncUninterruptibly(); + } + } + + private Channel startClient(final EventLoopGroup group, final int port, final String enabledProtocol) { + final Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(group); + bootstrap.channel(NioSocketChannel.class); + bootstrap.handler(new ChannelInitializer() { + @Override + protected void initChannel(final Channel channel) { + final ChannelPipeline pipeline = channel.pipeline(); + final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL); + setPipelineHandlers(pipeline, sslEngine); + } + }); + return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel(); + } + + private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) { + final ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(group); + bootstrap.channel(NioServerSocketChannel.class); + bootstrap.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(final Channel channel) { + final ChannelPipeline pipeline = channel.pipeline(); + final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL); + setPipelineHandlers(pipeline, sslEngine); + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception { + channelHandlerContext.channel().writeAndFlush(MESSAGE).sync(); + } + }); + } + }); + + final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, port); + bindFuture.syncUninterruptibly(); + } + + private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) { + final SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setUseClientMode(useClientMode); + sslEngine.setEnabledProtocols(new String[]{enabledProtocol}); + return sslEngine; + } + + private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) { + pipeline.addLast(new SslHandler(sslEngine)); + pipeline.addLast(new DelimiterBasedFrameDecoder(MAX_MESSAGE_LENGTH, Delimiters.lineDelimiter())); + pipeline.addLast(new StringDecoder()); + pipeline.addLast(new StringEncoder()); + } +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java index 70771f1cbd..e2f05cc34b 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java @@ -68,9 +68,10 @@ public class SSLSocketChannelSender extends SocketChannelSender { @Override public void close() { - super.close(); - IOUtils.closeQuietly(sslOutputStream); + // Close SSLSocketChannel before closing other resources IOUtils.closeQuietly(sslChannel); + IOUtils.closeQuietly(sslOutputStream); + super.close(); sslChannel = null; }