NIFI-9761 Correct PeerChannel processing for TLS 1.3 (#5836)

* NIFI-9761 Corrected PeerChannel processing for TLS 1.3
- Added TestPeerChannel with methods for TLS 1.2 and TLS 1.3
- Updated PeerChannel.close() to process SSLEngine close notification
- Improved logging and corrected handling after decryption
This commit is contained in:
exceptionfactory 2022-03-09 14:15:52 -06:00 committed by GitHub
parent 6a1c7c72d5
commit c73573b325
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 385 additions and 43 deletions

View File

@ -264,6 +264,11 @@
<version>${nifi.groovy.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>

View File

@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
@ -30,6 +31,10 @@ import java.nio.channels.SocketChannel;
import java.util.OptionalInt;
public class PeerChannel implements Closeable {
private static final int END_OF_FILE = -1;
private static final int EMPTY_BUFFER = 0;
private static final Logger logger = LoggerFactory.getLogger(PeerChannel.class);
private final SocketChannel socketChannel;
@ -38,7 +43,7 @@ public class PeerChannel implements Closeable {
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
private ByteBuffer destinationBuffer = ByteBuffer.allocate(16 * 1024); // buffer that SSLEngine is to write into
private ByteBuffer streamBuffer = ByteBuffer.allocate(16 * 1024); // buffer for data that is read from SocketChannel
private final ByteBuffer streamBuffer = ByteBuffer.allocate(16 * 1024); // buffer for data that is read from SocketChannel
private ByteBuffer applicationBuffer = ByteBuffer.allocate(0); // buffer for application-level data that is ready to be served up (i.e., already decrypted if necessary)
public PeerChannel(final SocketChannel socketChannel, final SSLEngine sslEngine, final String peerDescription) {
@ -47,10 +52,45 @@ public class PeerChannel implements Closeable {
this.peerDescription = peerDescription;
}
/**
* Close Socket Channel and process SSLEngine close notifications when configured
*
* @throws IOException Thrown on failure to close Socket Channel or process SSLEngine operations
*/
@Override
public void close() throws IOException {
socketChannel.close();
try {
if (sslEngine == null) {
logger.debug("Closing Peer Channel [{}] SSLEngine not configured", peerDescription);
} else {
logger.debug("Closing Peer Channel [{}] SSLEngine close started", peerDescription);
sslEngine.closeOutbound();
// Send TLS close notification packets available after initiating SSLEngine.closeOutbound()
final ByteBuffer inputBuffer = ByteBuffer.allocate(0);
final ByteBuffer outputBuffer = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
SSLEngineResult wrapResult = sslEngine.wrap(inputBuffer, outputBuffer);
SSLEngineResult.Status status = wrapResult.getStatus();
outputBuffer.flip();
if (SSLEngineResult.Status.OK == status) {
write(outputBuffer);
outputBuffer.clear();
wrapResult = sslEngine.wrap(inputBuffer, outputBuffer);
status = wrapResult.getStatus();
}
if (SSLEngineResult.Status.CLOSED == status) {
write(outputBuffer);
} else {
throw new SSLException(String.format("Closing Peer Channel [%s] Invalid Wrap Result Status [%s]", peerDescription, status));
}
logger.debug("Closing Peer Channel [{}] SSLEngine close completed", peerDescription);
}
} finally {
logger.debug("Closing Peer Channel [{}] Socket Channel close started", peerDescription);
socketChannel.close();
}
}
public boolean isConnected() {
@ -65,6 +105,13 @@ public class PeerChannel implements Closeable {
return peerDescription;
}
/**
* Write one byte to the channel
*
* @param b Byte to be written
* @return Status of write operation returns true on success
* @throws IOException Thrown on failure to write to the Socket Channel
*/
public boolean write(final byte b) throws IOException {
singleByteBuffer.clear();
singleByteBuffer.put(b);
@ -75,13 +122,18 @@ public class PeerChannel implements Closeable {
return bytesWritten > 0;
}
/**
* Read one byte as an unsigned integer from the channel
*
* @return Returns empty when zero bytes are available and returns negative one when the channel is closed
* @throws IOException Thrown on failure to read from Socket Channel
*/
public OptionalInt read() throws IOException {
singleByteBuffer.clear();
final int bytesRead = read(singleByteBuffer);
if (bytesRead < 0) {
return OptionalInt.of(-1);
}
if (bytesRead == 0) {
return OptionalInt.of(END_OF_FILE);
} else if (bytesRead == EMPTY_BUFFER) {
return OptionalInt.empty();
}
@ -91,9 +143,6 @@ public class PeerChannel implements Closeable {
return OptionalInt.of(read & 0xFF);
}
/**
* Reads the given ByteBuffer of data and returns a new ByteBuffer (which is "flipped" / ready to be read). The newly returned
* ByteBuffer will be written to be written via the {@link #write(ByteBuffer)} method. I.e., it will have already been encrypted, if
@ -104,11 +153,11 @@ public class PeerChannel implements Closeable {
* @throws IOException if a failure occurs while encrypting the data
*/
public ByteBuffer prepareForWrite(final ByteBuffer plaintext) throws IOException {
logger.trace("Channel [{}] Buffer wrap started: Input Bytes [{}]", peerDescription, plaintext.remaining());
if (sslEngine == null) {
return plaintext;
}
ByteBuffer prepared = ByteBuffer.allocate(Math.min(85, plaintext.capacity() - plaintext.position()));
while (plaintext.hasRemaining()) {
encrypt(plaintext);
@ -125,14 +174,28 @@ public class PeerChannel implements Closeable {
}
prepared.flip();
logger.trace("Channel [{}] Buffer wrap completed: Prepared Bytes [{}]", peerDescription, prepared.remaining());
return prepared;
}
/**
* Write prepared buffer to Socket Channel
*
* @param preparedBuffer Buffer must contain bytes processed through prepareForWrite() when TLS is enabled
* @return Number of bytes written according to SocketChannel.write()
* @throws IOException Thrown on failure to write to the Socket Channel
*/
public int write(final ByteBuffer preparedBuffer) throws IOException {
return socketChannel.write(preparedBuffer);
}
/**
* Read application data bytes into the provided buffer
*
* @param dst Buffer to be populated with application data bytes
* @return Number of bytes read into the provided buffer
* @throws IOException Thrown on failure to read from the Socket Channel
*/
public int read(final ByteBuffer dst) throws IOException {
// If we have data ready to go, then go ahead and copy it.
final int bytesCopied = copy(applicationBuffer, dst);
@ -141,12 +204,11 @@ public class PeerChannel implements Closeable {
}
final int bytesRead = socketChannel.read(streamBuffer);
if (bytesRead < 1) {
return bytesRead;
}
if (bytesRead > 0) {
logger.trace("Read {} bytes from SocketChannel", bytesRead);
logger.trace("Channel [{}] Socket read completed: bytes [{}]", peerDescription, bytesRead);
if (bytesRead == END_OF_FILE) {
return END_OF_FILE;
} else if (streamBuffer.remaining() == EMPTY_BUFFER) {
return EMPTY_BUFFER;
}
streamBuffer.flip();
@ -157,7 +219,7 @@ public class PeerChannel implements Closeable {
return copy(applicationBuffer, dst);
} else {
final boolean decrypted = decrypt(streamBuffer);
logger.trace("Decryption after reading those bytes successful = {}", decrypted);
logger.trace("Channel [{}] Decryption completed [{}]", peerDescription, decrypted);
if (decrypted) {
cloneToApplicationBuffer(destinationBuffer);
@ -167,9 +229,8 @@ public class PeerChannel implements Closeable {
} else {
// Not enough data to decrypt. Compact the buffer so that we keep the data we have
// but prepare the buffer to be written to again.
logger.debug("Not enough data to decrypt. Will need to consume more data before decrypting");
streamBuffer.compact();
return 0;
logger.trace("Channel [{}] Socket Channel read required", peerDescription);
return EMPTY_BUFFER;
}
}
} finally {
@ -219,6 +280,7 @@ public class PeerChannel implements Closeable {
while (true) {
final SSLEngineResult result = sslEngine.wrap(plaintext, destinationBuffer);
logOperationResult("WRAP", result);
switch (result.getStatus()) {
case OK:
@ -240,9 +302,6 @@ public class PeerChannel implements Closeable {
}
}
/**
* Attempts to decrypt the given buffer of data, writing the result into {@link #destinationBuffer}. If successful, will return <code>true</code>.
* If more data is needed in order to perform the decryption, will return <code>false</code>.
@ -260,15 +319,21 @@ public class PeerChannel implements Closeable {
while (true) {
final SSLEngineResult result = sslEngine.unwrap(encrypted, destinationBuffer);
logOperationResult("UNWRAP", result);
switch (result.getStatus()) {
case OK:
if (SSLEngineResult.HandshakeStatus.FINISHED == result.getHandshakeStatus()) {
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
// Break out of switch statement to call SSLEngine.unwrap() again
break;
}
destinationBuffer.flip();
return true;
case CLOSED:
throw new IOException("Failed to decrypt data from Peer " + peerDescription + " because Peer unexpectedly closed connection");
case BUFFER_OVERFLOW:
// ecnryptedBuffer is not large enough. Need to increase the size.
// encryptedBuffer is not large enough. Need to increase the size.
final ByteBuffer tempBuffer = ByteBuffer.allocate(encrypted.position() + sslEngine.getSession().getApplicationBufferSize());
destinationBuffer.flip();
tempBuffer.put(destinationBuffer);
@ -282,7 +347,11 @@ public class PeerChannel implements Closeable {
}
}
/**
* Perform TLS handshake when SSLEngine configured
*
* @throws IOException Thrown on failure to handle socket communication or TLS packet processing
*/
public void performHandshake() throws IOException {
if (sslEngine == null) {
return;
@ -295,18 +364,17 @@ public class PeerChannel implements Closeable {
while (true) {
final SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
logHandshakeStatus(handshakeStatus);
switch (handshakeStatus) {
case FINISHED:
case NOT_HANDSHAKING:
streamBuffer.clear();
destinationBuffer.clear();
logger.debug("Completed SSL Handshake with Peer {}", peerDescription);
logHandshakeCompleted();
return;
case NEED_TASK:
logger.debug("SSL Handshake with Peer {} Needs Task", peerDescription);
Runnable runnable;
while ((runnable = sslEngine.getDelegatedTask()) != null) {
runnable.run();
@ -314,27 +382,22 @@ public class PeerChannel implements Closeable {
break;
case NEED_WRAP:
logger.trace("SSL Handshake with Peer {} Needs Wrap", peerDescription);
encrypt(emptyMessage);
final int bytesWritten = write(destinationBuffer);
logger.debug("Wrote {} bytes for NEED_WRAP portion of Handshake", bytesWritten);
logHandshakeStatusBytes(handshakeStatus, "Socket write completed", bytesWritten);
break;
case NEED_UNWRAP:
logger.trace("SSL Handshake with Peer {} Needs Unwrap", peerDescription);
while (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
final boolean decrypted = decrypt(unwrapBuffer);
if (decrypted) {
logger.trace("Decryption was successful for NEED_UNWRAP portion of Handshake");
final SSLEngineResult.HandshakeStatus unwrapHandshakeStatus = sslEngine.getHandshakeStatus();
if (decrypted || unwrapHandshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
logHandshakeStatus(unwrapHandshakeStatus, "Decryption completed");
break;
}
if (unwrapBuffer.capacity() - unwrapBuffer.position() < 1) {
logger.trace("Enlarging size of Buffer for NEED_UNWRAP portion of Handshake");
// destinationBuffer is not large enough. Need to increase the size.
logHandshakeStatus(unwrapHandshakeStatus, "Increasing unwrap buffer for decryption");
final ByteBuffer tempBuffer = ByteBuffer.allocate(unwrapBuffer.capacity() + sslEngine.getSession().getApplicationBufferSize());
tempBuffer.put(unwrapBuffer);
unwrapBuffer = tempBuffer;
@ -342,17 +405,37 @@ public class PeerChannel implements Closeable {
continue;
}
logger.trace("Need to read more bytes for NEED_UNWRAP portion of Handshake");
// Need to read more data.
logHandshakeStatus(unwrapHandshakeStatus, "Socket read started");
unwrapBuffer.compact();
final int bytesRead = socketChannel.read(unwrapBuffer);
unwrapBuffer.flip();
logger.debug("Read {} bytes for NEED_UNWRAP portion of Handshake", bytesRead);
logHandshakeStatusBytes(unwrapHandshakeStatus, "Socket read completed", bytesRead);
}
break;
}
}
}
private void logOperationResult(final String operation, final SSLEngineResult sslEngineResult) {
logger.trace("Channel [{}] {} [{}]", peerDescription, operation, sslEngineResult);
}
private void logHandshakeCompleted() {
final SSLSession sslSession = sslEngine.getSession();
logger.debug("Channel [{}] Handshake Completed Protocol [{}] Cipher Suite [{}]", peerDescription, sslSession.getProtocol(), sslSession.getCipherSuite());
}
private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus) {
logger.debug("Channel [{}] Handshake Status [{}]", peerDescription, handshakeStatus);
}
private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus, final String operation) {
logger.debug("Channel [{}] Handshake Status [{}] {}", peerDescription, handshakeStatus, operation);
}
private void logHandshakeStatusBytes(final SSLEngineResult.HandshakeStatus handshakeStatus, final String operation, final int bytes) {
logger.debug("Channel [{}] Handshake Status [{}] {} Bytes [{}]", peerDescription, handshakeStatus, operation, bytes);
}
}

View File

@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.controller.queue.clustered.client.async.nio;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.ssl.SslHandler;
import org.apache.nifi.remote.io.socket.NetworkUtils;
import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsPlatform;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.condition.EnabledIf;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestPeerChannel {
private static final String LOCALHOST = "localhost";
private static final int GROUP_THREADS = 1;
private static final boolean CLIENT_CHANNEL = true;
private static final boolean SERVER_CHANNEL = false;
private static final long READ_SLEEP_INTERVAL = 500;
private static final int CHANNEL_TIMEOUT = 15000;
private static final int SOCKET_TIMEOUT = 5000;
private static final long SHUTDOWN_TIMEOUT = 100;
private static final String TLS_1_3 = "TLSv1.3";
private static final String TLS_1_2 = "TLSv1.2";
private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
private static final int PROTOCOL_VERSION = 1;
private static final int VERSION_ACCEPTED = 0x10;
private static SSLContext sslContext;
public static boolean isTls13Supported() {
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
}
@BeforeAll
public static void setConfiguration() throws GeneralSecurityException {
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
}
@Test
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
public void testConnectedClose() throws IOException {
final String enabledProtocol = getEnabledProtocol();
processChannel(enabledProtocol, peerChannel -> {});
}
@Test
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
public void testConnectedWriteReadCloseTls12() throws IOException {
assertWriteReadSuccess(TLS_1_2);
}
@EnabledIf(TLS_1_3_SUPPORTED)
@Test
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
public void testConnectedWriteReadCloseTls13() throws IOException {
assertWriteReadSuccess(TLS_1_3);
}
private void assertWriteReadSuccess(final String enabledProtocol) throws IOException {
processChannel(enabledProtocol, peerChannel -> {
try {
peerChannel.performHandshake();
final byte[] version = new byte[]{PROTOCOL_VERSION};
final ByteBuffer versionBuffer = ByteBuffer.wrap(version);
final ByteBuffer encryptedVersionBuffer = peerChannel.prepareForWrite(versionBuffer);
peerChannel.write(encryptedVersionBuffer);
final int firstByteRead = read(peerChannel);
assertEquals(PROTOCOL_VERSION, firstByteRead, "Peer Channel first byte read not matched");
final byte[] versionAccepted = new byte[]{VERSION_ACCEPTED};
final ByteBuffer versionAcceptedBuffer = ByteBuffer.wrap(versionAccepted);
final ByteBuffer encryptedVersionAcceptedBuffer = peerChannel.prepareForWrite(versionAcceptedBuffer);
peerChannel.write(encryptedVersionAcceptedBuffer);
final int secondByteRead = read(peerChannel);
assertEquals(VERSION_ACCEPTED, secondByteRead, "Peer Channel second byte read not matched");
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
});
}
private int read(final PeerChannel peerChannel) throws IOException {
OptionalInt read = peerChannel.read();
while (!read.isPresent()) {
try {
TimeUnit.MILLISECONDS.sleep(READ_SLEEP_INTERVAL);
} catch (InterruptedException e) {
throw new RuntimeException("Peer Channel read sleep interrupted", e);
}
read = peerChannel.read();
}
return read.getAsInt();
}
private void processChannel(final String enabledProtocol, final Consumer<PeerChannel> channelConsumer) throws IOException {
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try (final SocketChannel socketChannel = SocketChannel.open()) {
final Socket socket = socketChannel.socket();
socket.setSoTimeout(SOCKET_TIMEOUT);
final InetSocketAddress serverSocketAddress = getServerSocketAddress();
startServer(group, serverSocketAddress.getPort(), enabledProtocol);
socketChannel.connect(serverSocketAddress);
final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
final PeerChannel peerChannel = new PeerChannel(socketChannel, sslEngine, serverSocketAddress.toString());
assertConnectedOpen(peerChannel);
socketChannel.configureBlocking(false);
channelConsumer.accept(peerChannel);
peerChannel.close();
assertNotConnectedNotOpen(peerChannel);
} finally {
shutdownGroup(group);
}
}
private void assertConnectedOpen(final PeerChannel peerChannel) {
assertTrue(peerChannel.isConnected(), "Channel not connected");
assertTrue(peerChannel.isOpen(), "Channel not open");
}
private void assertNotConnectedNotOpen(final PeerChannel peerChannel) {
assertFalse(peerChannel.isConnected(), "Channel connected");
assertFalse(peerChannel.isOpen(), "Channel open");
}
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
final ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(group);
bootstrap.channel(NioServerSocketChannel.class);
bootstrap.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(final Channel channel) {
final ChannelPipeline pipeline = channel.pipeline();
final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL);
setPipelineHandlers(pipeline, sslEngine);
pipeline.addLast(new SimpleChannelInboundHandler<ByteBuf>() {
private int protocolVersion;
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf) throws Exception {
if (byteBuf.readableBytes() == 1) {
final int read = byteBuf.readByte();
if (PROTOCOL_VERSION == read) {
protocolVersion = read;
channelHandlerContext.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{PROTOCOL_VERSION}));
} else if (protocolVersion == PROTOCOL_VERSION) {
channelHandlerContext.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{VERSION_ACCEPTED}));
} else {
throw new SocketException(String.format("Unexpected Integer [%d] read", read));
}
}
}
});
}
});
final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, port);
bindFuture.syncUninterruptibly();
}
private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) {
final SSLEngine sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(useClientMode);
sslEngine.setEnabledProtocols(new String[]{enabledProtocol});
return sslEngine;
}
private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) {
pipeline.addLast(new SslHandler(sslEngine));
}
private void shutdownGroup(final EventLoopGroup group) {
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
}
private InetSocketAddress getServerSocketAddress() {
final int port = NetworkUtils.getAvailableTcpPort();
return new InetSocketAddress(LOCALHOST, port);
}
private String getEnabledProtocol() {
return isTls13Supported() ? TLS_1_3 : TLS_1_2;
}
}