NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3

- Added unit tests to reproduce issues with available() method
- Changed available() to return size of application buffer
- Removed unused isDataAvailable()
- Refactored unwrap handling to read from channel for buffer underflow

Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com>

This closes #5421.
This commit is contained in:
exceptionfactory 2021-09-28 17:00:17 -05:00 committed by Pierre Villard
parent ae0154de5a
commit defea61075
No known key found for this signature in database
GPG Key ID: F92A93B30C07C6D5
3 changed files with 168 additions and 149 deletions

View File

@ -30,6 +30,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import java.io.Closeable; import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
public class SSLSocketChannel implements Closeable { public class SSLSocketChannel implements Closeable {
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class); private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
private static final int MINIMUM_READ_BUFFER_SIZE = 1;
private static final int DISCARD_BUFFER_LENGTH = 8192; private static final int DISCARD_BUFFER_LENGTH = 8192;
private static final int END_OF_STREAM = -1; private static final int END_OF_STREAM = -1;
private static final byte[] EMPTY_MESSAGE = new byte[0]; private static final byte[] EMPTY_MESSAGE = new byte[0];
@ -266,7 +268,7 @@ public class SSLSocketChannel implements Closeable {
status = wrapResult.getStatus(); status = wrapResult.getStatus();
} }
if (Status.CLOSED == status) { if (Status.CLOSED == status) {
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1); final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
try { try {
writeChannel(streamOutputBuffer); writeChannel(streamOutputBuffer);
} catch (final IOException e) { } catch (final IOException e) {
@ -291,39 +293,8 @@ public class SSLSocketChannel implements Closeable {
* @throws IOException Thrown on failures checking for available bytes * @throws IOException Thrown on failures checking for available bytes
*/ */
public int available() throws IOException { public int available() throws IOException {
ByteBuffer appDataBuffer = appDataManager.prepareForRead(1); final ByteBuffer appDataBuffer = appDataManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1); return appDataBuffer.remaining();
final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
if (buffered > 0) {
return buffered;
}
if (!isDataAvailable()) {
return 0;
}
appDataBuffer = appDataManager.prepareForRead(1);
streamDataBuffer = streamInManager.prepareForRead(1);
return appDataBuffer.remaining() + streamDataBuffer.remaining();
}
/**
* Is data available for reading
*
* @return Data available status
* @throws IOException Thrown on SocketChannel.read() failures
*/
public boolean isDataAvailable() throws IOException {
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
return true;
}
final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int bytesRead = channel.read(writableBuffer);
return (bytesRead > 0);
} }
/** /**
@ -373,42 +344,24 @@ public class SSLSocketChannel implements Closeable {
} }
appDataManager.clear(); appDataManager.clear();
while (true) { final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
final SSLEngineResult unwrapResult = unwrap(); final Status status = unwrapResult.getStatus();
if (Status.CLOSED == status) {
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) { applicationBytesRead = readApplicationBuffer(buffer, offset, len);
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3 if (applicationBytesRead == 0) {
logOperation("Processing Post-Handshake Messages"); return END_OF_STREAM;
continue;
} }
streamInManager.compact();
final Status status = unwrapResult.getStatus(); return applicationBytesRead;
switch (status) { } else if (Status.OK == status) {
case BUFFER_OVERFLOW: applicationBytesRead = readApplicationBuffer(buffer, offset, len);
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status)); if (applicationBytesRead == 0) {
case BUFFER_UNDERFLOW: throw new IOException("Read Application Buffer Failed");
final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int channelBytesRead = readChannel(streamBuffer);
logOperationBytes("Channel Read Completed", channelBytesRead);
if (channelBytesRead == END_OF_STREAM) {
return END_OF_STREAM;
}
break;
case CLOSED:
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
return END_OF_STREAM;
}
streamInManager.compact();
return applicationBytesRead;
case OK:
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
if (applicationBytesRead == 0) {
throw new IOException("Read Application Buffer Failed");
}
streamInManager.compact();
return applicationBytesRead;
} }
streamInManager.compact();
return applicationBytesRead;
} else {
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not expected from unwrap", status));
} }
} }
@ -508,24 +461,13 @@ public class SSLSocketChannel implements Closeable {
handshakeStatus = engine.getHandshakeStatus(); handshakeStatus = engine.getHandshakeStatus();
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
final SSLEngineResult unwrapResult = unwrap(); final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
handshakeStatus = unwrapResult.getHandshakeStatus(); handshakeStatus = unwrapResult.getHandshakeStatus();
Status unwrapResultStatus = unwrapResult.getStatus(); if (unwrapResult.getStatus() == Status.CLOSED) {
if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final int bytesRead = readChannel(writableDataIn);
logOperationBytes("Handshake Channel Read", bytesRead);
if (bytesRead == END_OF_STREAM) {
throw getHandshakeException(handshakeStatus, "End of Stream Found");
}
} else if (unwrapResultStatus == Status.CLOSED) {
throw getHandshakeException(handshakeStatus, "Channel Closed"); throw getHandshakeException(handshakeStatus, "Channel Closed");
} else {
streamInManager.compact();
appDataManager.clear();
} }
streamInManager.compact();
appDataManager.clear();
break; break;
case NEED_WRAP: case NEED_WRAP:
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
@ -536,7 +478,7 @@ public class SSLSocketChannel implements Closeable {
if (wrapResultStatus == Status.BUFFER_OVERFLOW) { if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize()); streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
} else if (wrapResultStatus == Status.OK) { } else if (wrapResultStatus == Status.OK) {
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1); final ByteBuffer streamBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
final int bytesRemaining = streamBuffer.remaining(); final int bytesRemaining = streamBuffer.remaining();
writeChannel(streamBuffer); writeChannel(streamBuffer);
logOperationBytes("Handshake Channel Write Completed", bytesRemaining); logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
@ -549,8 +491,29 @@ public class SSLSocketChannel implements Closeable {
} }
} }
private int readChannel(final ByteBuffer outputBuffer) throws IOException { private SSLEngineResult unwrapBufferReadChannel() throws IOException {
SSLEngineResult unwrapResult = unwrap();
while (Status.BUFFER_UNDERFLOW == unwrapResult.getStatus()) {
final int channelBytesRead = readChannel();
if (channelBytesRead == END_OF_STREAM) {
throw new EOFException("End of Stream found for Channel Read");
}
unwrapResult = unwrap();
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
logOperation("Processing Post-Handshake Messages");
unwrapResult = unwrap();
}
}
return unwrapResult;
}
private int readChannel() throws IOException {
logOperation("Channel Read Started"); logOperation("Channel Read Started");
final ByteBuffer outputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
final long started = System.currentTimeMillis(); final long started = System.currentTimeMillis();
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP; long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
@ -568,10 +531,24 @@ public class SSLSocketChannel implements Closeable {
continue; continue;
} }
logOperationBytes("Channel Read Completed", channelBytesRead);
return channelBytesRead; return channelBytesRead;
} }
} }
private void readChannelDiscard() {
try {
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
int bytesRead = channel.read(readBuffer);
while (bytesRead > 0) {
readBuffer.clear();
bytesRead = channel.read(readBuffer);
}
} catch (final IOException e) {
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
}
}
private void writeChannel(final ByteBuffer inputBuffer) throws IOException { private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
long lastWriteCompleted = System.currentTimeMillis(); long lastWriteCompleted = System.currentTimeMillis();
@ -605,19 +582,6 @@ public class SSLSocketChannel implements Closeable {
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS); return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
} }
private void readChannelDiscard() {
try {
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
int bytesRead = channel.read(readBuffer);
while (bytesRead > 0) {
readBuffer.clear();
bytesRead = channel.read(readBuffer);
}
} catch (final IOException e) {
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
}
}
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) { private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
logOperationBytes("Application Buffer Read Requested", len); logOperationBytes("Application Buffer Read Requested", len);
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len); final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);

View File

@ -58,8 +58,4 @@ public class SSLSocketChannelInputStream extends InputStream {
public int available() throws IOException { public int available() throws IOException {
return channel.available(); return channel.available();
} }
public boolean isDataAvailable() throws IOException {
return available() > 0;
}
} }

View File

@ -38,9 +38,10 @@ import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder; import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration; import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsPlatform; import org.apache.nifi.security.util.TlsPlatform;
import org.junit.Assume; import org.junit.jupiter.api.BeforeAll;
import org.junit.BeforeClass; import org.junit.jupiter.api.Test;
import org.junit.Test; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.condition.EnabledIf;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
@ -55,17 +56,19 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Timeout(value = 15)
public class SSLSocketChannelTest { public class SSLSocketChannelTest {
private static final String LOCALHOST = "localhost"; private static final String LOCALHOST = "localhost";
@ -81,10 +84,10 @@ public class SSLSocketChannelTest {
private static final int CHANNEL_POLL_TIMEOUT = 5000; private static final int CHANNEL_POLL_TIMEOUT = 5000;
private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
private static final int MAX_MESSAGE_LENGTH = 1024; private static final int MAX_MESSAGE_LENGTH = 1024;
private static final long SHUTDOWN_TIMEOUT = 100;
private static final String TLS_1_3 = "TLSv1.3"; private static final String TLS_1_3 = "TLSv1.3";
private static final String TLS_1_2 = "TLSv1.2"; private static final String TLS_1_2 = "TLSv1.2";
@ -97,9 +100,17 @@ public class SSLSocketChannelTest {
private static final int FIRST_BYTE_OFFSET = 1; private static final int FIRST_BYTE_OFFSET = 1;
private static final int SINGLE_COUNT_DOWN = 1;
private static SSLContext sslContext; private static SSLContext sslContext;
@BeforeClass private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
public static boolean isTls13Supported() {
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
}
@BeforeAll
public static void setConfiguration() throws GeneralSecurityException { public static void setConfiguration() throws GeneralSecurityException {
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build(); final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
sslContext = SslContextFactory.createSslContext(tlsConfiguration); sslContext = SslContextFactory.createSslContext(tlsConfiguration);
@ -115,54 +126,60 @@ public class SSLSocketChannelTest {
@Test @Test
public void testClientConnectHandshakeFailed() throws IOException { public void testClientConnectHandshakeFailed() throws IOException {
assumeProtocolSupported(TLS_1_2); final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS); final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try (final SocketChannel socketChannel = SocketChannel.open()) { try (final SocketChannel socketChannel = SocketChannel.open()) {
final int port = NetworkUtils.getAvailableTcpPort(); final int port = NetworkUtils.getAvailableTcpPort();
startServer(group, port, TLS_1_2); startServer(group, port, enabledProtocol, getSingleCountDownLatch());
socketChannel.connect(new InetSocketAddress(LOCALHOST, port)); socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL); final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel); final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT); sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
group.shutdownGracefully().syncUninterruptibly(); shutdownGroup(group);
assertThrows(SSLException.class, sslSocketChannel::connect); assertThrows(SSLException.class, sslSocketChannel::connect);
} finally { } finally {
group.shutdownGracefully().syncUninterruptibly(); shutdownGroup(group);
} }
} }
@Test @Test
public void testClientConnectWriteReadTls12() throws Exception { public void testClientConnectWriteReadTls12() throws Exception {
assumeProtocolSupported(TLS_1_2);
assertChannelConnectedWriteReadClosed(TLS_1_2); assertChannelConnectedWriteReadClosed(TLS_1_2);
} }
@EnabledIf(TLS_1_3_SUPPORTED)
@Test @Test
public void testClientConnectWriteReadTls13() throws Exception { public void testClientConnectWriteReadTls13() throws Exception {
assumeProtocolSupported(TLS_1_3);
assertChannelConnectedWriteReadClosed(TLS_1_3); assertChannelConnectedWriteReadClosed(TLS_1_3);
} }
@Test(timeout = CHANNEL_TIMEOUT) @Test
public void testClientConnectWriteAvailableReadTls12() throws Exception {
assertChannelConnectedWriteAvailableRead(TLS_1_2);
}
@EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testClientConnectWriteAvailableReadTls13() throws Exception {
assertChannelConnectedWriteAvailableRead(TLS_1_3);
}
@Test
public void testServerReadWriteTls12() throws Exception { public void testServerReadWriteTls12() throws Exception {
assumeProtocolSupported(TLS_1_2);
assertServerChannelConnectedReadClosed(TLS_1_2); assertServerChannelConnectedReadClosed(TLS_1_2);
} }
@Test(timeout = CHANNEL_TIMEOUT) @EnabledIf(TLS_1_3_SUPPORTED)
@Test
public void testServerReadWriteTls13() throws Exception { public void testServerReadWriteTls13() throws Exception {
assumeProtocolSupported(TLS_1_3);
assertServerChannelConnectedReadClosed(TLS_1_3); assertServerChannelConnectedReadClosed(TLS_1_3);
} }
private void assumeProtocolSupported(final String protocol) {
Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
}
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException { private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
final int port = NetworkUtils.getAvailableTcpPort(); final int port = NetworkUtils.getAvailableTcpPort();
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
@ -194,67 +211,100 @@ public class SSLSocketChannelTest {
channel.writeAndFlush(MESSAGE).syncUninterruptibly(); channel.writeAndFlush(MESSAGE).syncUninterruptibly();
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS); final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
assertEquals("Message not matched", MESSAGE, messageRead); assertEquals(MESSAGE, messageRead, "Message not matched");
} finally { } finally {
channel.close(); channel.close();
} }
} finally { } finally {
group.shutdownGracefully().syncUninterruptibly(); shutdownGroup(group);
serverSocketChannel.close(); serverSocketChannel.close();
} }
} }
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException { private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> { final CountDownLatch countDownLatch = getSingleCountDownLatch();
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
try { try {
sslSocketChannel.connect(); sslSocketChannel.connect();
assertFalse("Channel closed", sslSocketChannel.isClosed()); assertFalse(sslSocketChannel.isClosed());
assertChannelWriteRead(sslSocketChannel); assertChannelWriteRead(sslSocketChannel, countDownLatch);
sslSocketChannel.close(); sslSocketChannel.close();
assertTrue("Channel not closed", sslSocketChannel.isClosed()); assertTrue(sslSocketChannel.isClosed());
} catch (final IOException e) { } catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e); throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
} }
})); }));
} }
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException { private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES); final CountDownLatch countDownLatch = getSingleCountDownLatch();
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
while (sslSocketChannel.available() == 0) {
try { try {
TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ); sslSocketChannel.connect();
} catch (final InterruptedException e) { assertFalse(sslSocketChannel.isClosed());
throw new RuntimeException(e);
}
}
assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);
sslSocketChannel.close();
assertTrue(sslSocketChannel.isClosed());
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
}
}));
}
private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES);
sslSocketChannel.available();
awaitCountDownLatch(countDownLatch);
assetMessageRead(sslSocketChannel);
}
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
sslSocketChannel.write(MESSAGE_BYTES);
awaitCountDownLatch(countDownLatch);
assetMessageRead(sslSocketChannel);
}
private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
try {
countDownLatch.await();
} catch (final InterruptedException e) {
throw new IOException("Count Down Interrupted", e);
}
}
private void assetMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
final byte firstByteRead = (byte) sslSocketChannel.read(); final byte firstByteRead = (byte) sslSocketChannel.read();
assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead); assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");
final int available = sslSocketChannel.available();
final int availableExpected = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
assertEquals(availableExpected, available, "Available Bytes not matched");
final byte[] messageBytes = new byte[MESSAGE_BYTES.length]; final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
messageBytes[0] = firstByteRead; messageBytes[0] = firstByteRead;
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length); final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead); assertEquals(messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead, "Channel Message Bytes Read not matched");
final String message = new String(messageBytes, MESSAGE_CHARSET); final String message = new String(messageBytes, MESSAGE_CHARSET);
assertEquals("Channel Message not matched", MESSAGE, message); assertEquals(MESSAGE, message, "Message not matched");
} }
private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException { private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS); final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
try { try {
final int port = NetworkUtils.getAvailableTcpPort(); final int port = NetworkUtils.getAvailableTcpPort();
startServer(group, port, enabledProtocol); startServer(group, port, enabledProtocol, countDownLatch);
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL); final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT); sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
channelConsumer.accept(sslSocketChannel); channelConsumer.accept(sslSocketChannel);
} finally { } finally {
group.shutdownGracefully().syncUninterruptibly(); shutdownGroup(group);
} }
} }
@ -273,7 +323,7 @@ public class SSLSocketChannelTest {
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel(); return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
} }
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) { private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol, final CountDownLatch countDownLatch) {
final ServerBootstrap bootstrap = new ServerBootstrap(); final ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(group); bootstrap.group(group);
bootstrap.channel(NioServerSocketChannel.class); bootstrap.channel(NioServerSocketChannel.class);
@ -287,6 +337,7 @@ public class SSLSocketChannelTest {
@Override @Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception { protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync(); channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
countDownLatch.countDown();
} }
}); });
} }
@ -309,4 +360,12 @@ public class SSLSocketChannelTest {
pipeline.addLast(new StringDecoder()); pipeline.addLast(new StringDecoder());
pipeline.addLast(new StringEncoder()); pipeline.addLast(new StringEncoder());
} }
private void shutdownGroup(final EventLoopGroup group) {
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
}
private CountDownLatch getSingleCountDownLatch() {
return new CountDownLatch(SINGLE_COUNT_DOWN);
}
} }