NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3

- Added unit tests to reproduce issues with available() method
- Changed available() to return size of application buffer
- Removed unused isDataAvailable()
- Refactored unwrap handling to read from channel for buffer underflow

Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com>

This closes #5421.
This commit is contained in:
exceptionfactory 2021-09-28 17:00:17 -05:00 committed by Pierre Villard
parent ae0154de5a
commit defea61075
No known key found for this signature in database
GPG Key ID: F92A93B30C07C6D5
3 changed files with 168 additions and 149 deletions

View File

@ -30,6 +30,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
public class SSLSocketChannel implements Closeable {
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
private static final int MINIMUM_READ_BUFFER_SIZE = 1;
private static final int DISCARD_BUFFER_LENGTH = 8192;
private static final int END_OF_STREAM = -1;
private static final byte[] EMPTY_MESSAGE = new byte[0];
@ -266,7 +268,7 @@ public class SSLSocketChannel implements Closeable {
status = wrapResult.getStatus();
}
if (Status.CLOSED == status) {
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
try {
writeChannel(streamOutputBuffer);
} catch (final IOException e) {
@ -291,39 +293,8 @@ public class SSLSocketChannel implements Closeable {
* @throws IOException Thrown on failures checking for available bytes
*/
public int available() throws IOException {
ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
if (buffered > 0) {
return buffered;
}
if (!isDataAvailable()) {
return 0;
}
appDataBuffer = appDataManager.prepareForRead(1);
streamDataBuffer = streamInManager.prepareForRead(1);
return appDataBuffer.remaining() + streamDataBuffer.remaining();
}
/**
* Is data available for reading
*
* @return Data available status
* @throws IOException Thrown on SocketChannel.read() failures
*/
public boolean isDataAvailable() throws IOException {
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
return true;
}
final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int bytesRead = channel.read(writableBuffer);
return (bytesRead > 0);
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
return appDataBuffer.remaining();
}
/**
@ -373,42 +344,24 @@ public class SSLSocketChannel implements Closeable {
}
appDataManager.clear();
while (true) {
final SSLEngineResult unwrapResult = unwrap();
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
logOperation("Processing Post-Handshake Messages");
continue;
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
final Status status = unwrapResult.getStatus();
if (Status.CLOSED == status) {
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
return END_OF_STREAM;
}
final Status status = unwrapResult.getStatus();
switch (status) {
case BUFFER_OVERFLOW:
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
case BUFFER_UNDERFLOW:
final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int channelBytesRead = readChannel(streamBuffer);
logOperationBytes("Channel Read Completed", channelBytesRead);
if (channelBytesRead == END_OF_STREAM) {
return END_OF_STREAM;
}
break;
case CLOSED:
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
return END_OF_STREAM;
}
streamInManager.compact();
return applicationBytesRead;
case OK:
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
throw new IOException("Read Application Buffer Failed");
}
streamInManager.compact();
return applicationBytesRead;
streamInManager.compact();
return applicationBytesRead;
} else if (Status.OK == status) {
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
throw new IOException("Read Application Buffer Failed");
}
streamInManager.compact();
return applicationBytesRead;
} else {
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not expected from unwrap", status));
}
}
@ -508,24 +461,13 @@ public class SSLSocketChannel implements Closeable {
handshakeStatus = engine.getHandshakeStatus();
break;
case NEED_UNWRAP:
final SSLEngineResult unwrapResult = unwrap();
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
handshakeStatus = unwrapResult.getHandshakeStatus();
Status unwrapResultStatus = unwrapResult.getStatus();
if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int bytesRead = readChannel(writableDataIn);
logOperationBytes("Handshake Channel Read", bytesRead);
if (bytesRead == END_OF_STREAM) {
throw getHandshakeException(handshakeStatus, "End of Stream Found");
}
} else if (unwrapResultStatus == Status.CLOSED) {
if (unwrapResult.getStatus() == Status.CLOSED) {
throw getHandshakeException(handshakeStatus, "Channel Closed");
} else {
streamInManager.compact();
appDataManager.clear();
}
streamInManager.compact();
appDataManager.clear();
break;
case NEED_WRAP:
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
@ -536,7 +478,7 @@ public class SSLSocketChannel implements Closeable {
if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
} else if (wrapResultStatus == Status.OK) {
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
final int bytesRemaining = streamBuffer.remaining();
writeChannel(streamBuffer);
logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
@ -549,8 +491,29 @@ public class SSLSocketChannel implements Closeable {
}
}
private int readChannel(final ByteBuffer outputBuffer) throws IOException {
private SSLEngineResult unwrapBufferReadChannel() throws IOException {
SSLEngineResult unwrapResult = unwrap();
while (Status.BUFFER_UNDERFLOW == unwrapResult.getStatus()) {
final int channelBytesRead = readChannel();
if (channelBytesRead == END_OF_STREAM) {
throw new EOFException("End of Stream found for Channel Read");
}
unwrapResult = unwrap();
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
logOperation("Processing Post-Handshake Messages");
unwrapResult = unwrap();
}
}
return unwrapResult;
}
private int readChannel() throws IOException {
logOperation("Channel Read Started");
final ByteBuffer outputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final long started = System.currentTimeMillis();
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
@ -568,10 +531,24 @@ public class SSLSocketChannel implements Closeable {
continue;
}
logOperationBytes("Channel Read Completed", channelBytesRead);
return channelBytesRead;
}
}
private void readChannelDiscard() {
try {
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
int bytesRead = channel.read(readBuffer);
while (bytesRead > 0) {
readBuffer.clear();
bytesRead = channel.read(readBuffer);
}
} catch (final IOException e) {
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
}
}
private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
long lastWriteCompleted = System.currentTimeMillis();
@ -605,19 +582,6 @@ public class SSLSocketChannel implements Closeable {
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
}
private void readChannelDiscard() {
try {
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
int bytesRead = channel.read(readBuffer);
while (bytesRead > 0) {
readBuffer.clear();
bytesRead = channel.read(readBuffer);
}
} catch (final IOException e) {
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
}
}
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
logOperationBytes("Application Buffer Read Requested", len);
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);

View File

@ -58,8 +58,4 @@ public class SSLSocketChannelInputStream extends InputStream {
public int available() throws IOException {
return channel.available();
}
public boolean isDataAvailable() throws IOException {
return available() > 0;
}
}

View File

@ -38,9 +38,10 @@ import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsPlatform;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.condition.EnabledIf;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
@ -55,17 +56,19 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Timeout(value = 15)
public class SSLSocketChannelTest {
private static final String LOCALHOST = "localhost";
@ -81,10 +84,10 @@ public class SSLSocketChannelTest {
private static final int CHANNEL_POLL_TIMEOUT = 5000;
private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
private static final int MAX_MESSAGE_LENGTH = 1024;
private static final long SHUTDOWN_TIMEOUT = 100;
private static final String TLS_1_3 = "TLSv1.3";
private static final String TLS_1_2 = "TLSv1.2";
@ -97,9 +100,17 @@ public class SSLSocketChannelTest {
private static final int FIRST_BYTE_OFFSET = 1;
private static final int SINGLE_COUNT_DOWN = 1;
private static SSLContext sslContext;
@BeforeClass
private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
public static boolean isTls13Supported() {
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
}
@BeforeAll
public static void setConfiguration() throws GeneralSecurityException {
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
@ -115,54 +126,60 @@ public class SSLSocketChannelTest {
@Test
public void testClientConnectHandshakeFailed() throws IOException {
assumeProtocolSupported(TLS_1_2);
final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try (final SocketChannel socketChannel = SocketChannel.open()) {
final int port = NetworkUtils.getAvailableTcpPort();
startServer(group, port, TLS_1_2);
startServer(group, port, enabledProtocol, getSingleCountDownLatch());
socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
group.shutdownGracefully().syncUninterruptibly();
shutdownGroup(group);
assertThrows(SSLException.class, sslSocketChannel::connect);
} finally {
group.shutdownGracefully().syncUninterruptibly();
shutdownGroup(group);
}
}
@Test
public void testClientConnectWriteReadTls12() throws Exception {
assumeProtocolSupported(TLS_1_2);
assertChannelConnectedWriteReadClosed(TLS_1_2);
}
@EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testClientConnectWriteReadTls13() throws Exception {
assumeProtocolSupported(TLS_1_3);
assertChannelConnectedWriteReadClosed(TLS_1_3);
}
@Test(timeout = CHANNEL_TIMEOUT)
@Test
public void testClientConnectWriteAvailableReadTls12() throws Exception {
assertChannelConnectedWriteAvailableRead(TLS_1_2);
}
@EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testClientConnectWriteAvailableReadTls13() throws Exception {
assertChannelConnectedWriteAvailableRead(TLS_1_3);
}
@Test
public void testServerReadWriteTls12() throws Exception {
assumeProtocolSupported(TLS_1_2);
assertServerChannelConnectedReadClosed(TLS_1_2);
}
@Test(timeout = CHANNEL_TIMEOUT)
@EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testServerReadWriteTls13() throws Exception {
assumeProtocolSupported(TLS_1_3);
assertServerChannelConnectedReadClosed(TLS_1_3);
}
private void assumeProtocolSupported(final String protocol) {
Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
}
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
final int port = NetworkUtils.getAvailableTcpPort();
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
@ -194,67 +211,100 @@ public class SSLSocketChannelTest {
channel.writeAndFlush(MESSAGE).syncUninterruptibly();
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
assertEquals("Message not matched", MESSAGE, messageRead);
assertEquals(MESSAGE, messageRead, "Message not matched");
} finally {
channel.close();
}
} finally {
group.shutdownGracefully().syncUninterruptibly();
shutdownGroup(group);
serverSocketChannel.close();
}
}
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
final CountDownLatch countDownLatch = getSingleCountDownLatch();
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
try {
sslSocketChannel.connect();
assertFalse("Channel closed", sslSocketChannel.isClosed());
assertFalse(sslSocketChannel.isClosed());
assertChannelWriteRead(sslSocketChannel);
assertChannelWriteRead(sslSocketChannel, countDownLatch);
sslSocketChannel.close();
assertTrue("Channel not closed", sslSocketChannel.isClosed());
assertTrue(sslSocketChannel.isClosed());
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
}));
}
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES);
while (sslSocketChannel.available() == 0) {
private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
final CountDownLatch countDownLatch = getSingleCountDownLatch();
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
try {
TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}
}
sslSocketChannel.connect();
assertFalse(sslSocketChannel.isClosed());
assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);
sslSocketChannel.close();
assertTrue(sslSocketChannel.isClosed());
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
}));
}
private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES);
sslSocketChannel.available();
awaitCountDownLatch(countDownLatch);
assetMessageRead(sslSocketChannel);
}
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES);
awaitCountDownLatch(countDownLatch);
assetMessageRead(sslSocketChannel);
}
private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
try {
countDownLatch.await();
} catch (final InterruptedException e) {
throw new IOException("Count Down Interrupted", e);
}
}
private void assetMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
final byte firstByteRead = (byte) sslSocketChannel.read();
assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");
final int available = sslSocketChannel.available();
final int availableExpected = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
assertEquals(availableExpected, available, "Available Bytes not matched");
final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
messageBytes[0] = firstByteRead;
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
assertEquals(messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead, "Channel Message Bytes Read not matched");
final String message = new String(messageBytes, MESSAGE_CHARSET);
assertEquals("Channel Message not matched", MESSAGE, message);
assertEquals(MESSAGE, message, "Message not matched");
}
private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try {
final int port = NetworkUtils.getAvailableTcpPort();
startServer(group, port, enabledProtocol);
startServer(group, port, enabledProtocol, countDownLatch);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
channelConsumer.accept(sslSocketChannel);
} finally {
group.shutdownGracefully().syncUninterruptibly();
shutdownGroup(group);
}
}
@ -273,7 +323,7 @@ public class SSLSocketChannelTest {
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
}
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol, final CountDownLatch countDownLatch) {
final ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(group);
bootstrap.channel(NioServerSocketChannel.class);
@ -287,6 +337,7 @@ public class SSLSocketChannelTest {
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
countDownLatch.countDown();
}
});
}
@ -309,4 +360,12 @@ public class SSLSocketChannelTest {
pipeline.addLast(new StringDecoder());
pipeline.addLast(new StringEncoder());
}
private void shutdownGroup(final EventLoopGroup group) {
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
}
private CountDownLatch getSingleCountDownLatch() {
return new CountDownLatch(SINGLE_COUNT_DOWN);
}
}