mirror of https://github.com/apache/nifi.git
NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3
- Added unit tests to reproduce issues with available() method - Changed available() to return size of application buffer - Removed unused isDataAvailable() - Refactored unwrap handling to read from channel for buffer underflow Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com> This closes #5421.
This commit is contained in:
parent
ae0154de5a
commit
defea61075
|
@ -30,6 +30,7 @@ import javax.net.ssl.SSLException;
|
||||||
import javax.net.ssl.SSLHandshakeException;
|
import javax.net.ssl.SSLHandshakeException;
|
||||||
import javax.net.ssl.SSLSession;
|
import javax.net.ssl.SSLSession;
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
|
import java.io.EOFException;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
|
@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
|
||||||
public class SSLSocketChannel implements Closeable {
|
public class SSLSocketChannel implements Closeable {
|
||||||
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
|
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
|
||||||
|
|
||||||
|
private static final int MINIMUM_READ_BUFFER_SIZE = 1;
|
||||||
private static final int DISCARD_BUFFER_LENGTH = 8192;
|
private static final int DISCARD_BUFFER_LENGTH = 8192;
|
||||||
private static final int END_OF_STREAM = -1;
|
private static final int END_OF_STREAM = -1;
|
||||||
private static final byte[] EMPTY_MESSAGE = new byte[0];
|
private static final byte[] EMPTY_MESSAGE = new byte[0];
|
||||||
|
@ -266,7 +268,7 @@ public class SSLSocketChannel implements Closeable {
|
||||||
status = wrapResult.getStatus();
|
status = wrapResult.getStatus();
|
||||||
}
|
}
|
||||||
if (Status.CLOSED == status) {
|
if (Status.CLOSED == status) {
|
||||||
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
|
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||||
try {
|
try {
|
||||||
writeChannel(streamOutputBuffer);
|
writeChannel(streamOutputBuffer);
|
||||||
} catch (final IOException e) {
|
} catch (final IOException e) {
|
||||||
|
@ -291,39 +293,8 @@ public class SSLSocketChannel implements Closeable {
|
||||||
* @throws IOException Thrown on failures checking for available bytes
|
* @throws IOException Thrown on failures checking for available bytes
|
||||||
*/
|
*/
|
||||||
public int available() throws IOException {
|
public int available() throws IOException {
|
||||||
ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
|
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||||
ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
|
return appDataBuffer.remaining();
|
||||||
final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
|
|
||||||
if (buffered > 0) {
|
|
||||||
return buffered;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isDataAvailable()) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
appDataBuffer = appDataManager.prepareForRead(1);
|
|
||||||
streamDataBuffer = streamInManager.prepareForRead(1);
|
|
||||||
return appDataBuffer.remaining() + streamDataBuffer.remaining();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Is data available for reading
|
|
||||||
*
|
|
||||||
* @return Data available status
|
|
||||||
* @throws IOException Thrown on SocketChannel.read() failures
|
|
||||||
*/
|
|
||||||
public boolean isDataAvailable() throws IOException {
|
|
||||||
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
|
|
||||||
final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
|
|
||||||
|
|
||||||
if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
|
||||||
final int bytesRead = channel.read(writableBuffer);
|
|
||||||
return (bytesRead > 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -373,42 +344,24 @@ public class SSLSocketChannel implements Closeable {
|
||||||
}
|
}
|
||||||
appDataManager.clear();
|
appDataManager.clear();
|
||||||
|
|
||||||
while (true) {
|
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
|
||||||
final SSLEngineResult unwrapResult = unwrap();
|
final Status status = unwrapResult.getStatus();
|
||||||
|
if (Status.CLOSED == status) {
|
||||||
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
|
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||||
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
|
if (applicationBytesRead == 0) {
|
||||||
logOperation("Processing Post-Handshake Messages");
|
return END_OF_STREAM;
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
streamInManager.compact();
|
||||||
final Status status = unwrapResult.getStatus();
|
return applicationBytesRead;
|
||||||
switch (status) {
|
} else if (Status.OK == status) {
|
||||||
case BUFFER_OVERFLOW:
|
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||||
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
|
if (applicationBytesRead == 0) {
|
||||||
case BUFFER_UNDERFLOW:
|
throw new IOException("Read Application Buffer Failed");
|
||||||
final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
|
||||||
final int channelBytesRead = readChannel(streamBuffer);
|
|
||||||
logOperationBytes("Channel Read Completed", channelBytesRead);
|
|
||||||
if (channelBytesRead == END_OF_STREAM) {
|
|
||||||
return END_OF_STREAM;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case CLOSED:
|
|
||||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
|
||||||
if (applicationBytesRead == 0) {
|
|
||||||
return END_OF_STREAM;
|
|
||||||
}
|
|
||||||
streamInManager.compact();
|
|
||||||
return applicationBytesRead;
|
|
||||||
case OK:
|
|
||||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
|
||||||
if (applicationBytesRead == 0) {
|
|
||||||
throw new IOException("Read Application Buffer Failed");
|
|
||||||
}
|
|
||||||
streamInManager.compact();
|
|
||||||
return applicationBytesRead;
|
|
||||||
}
|
}
|
||||||
|
streamInManager.compact();
|
||||||
|
return applicationBytesRead;
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not expected from unwrap", status));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -508,24 +461,13 @@ public class SSLSocketChannel implements Closeable {
|
||||||
handshakeStatus = engine.getHandshakeStatus();
|
handshakeStatus = engine.getHandshakeStatus();
|
||||||
break;
|
break;
|
||||||
case NEED_UNWRAP:
|
case NEED_UNWRAP:
|
||||||
final SSLEngineResult unwrapResult = unwrap();
|
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
|
||||||
handshakeStatus = unwrapResult.getHandshakeStatus();
|
handshakeStatus = unwrapResult.getHandshakeStatus();
|
||||||
Status unwrapResultStatus = unwrapResult.getStatus();
|
if (unwrapResult.getStatus() == Status.CLOSED) {
|
||||||
|
|
||||||
if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
|
|
||||||
final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
|
||||||
final int bytesRead = readChannel(writableDataIn);
|
|
||||||
logOperationBytes("Handshake Channel Read", bytesRead);
|
|
||||||
|
|
||||||
if (bytesRead == END_OF_STREAM) {
|
|
||||||
throw getHandshakeException(handshakeStatus, "End of Stream Found");
|
|
||||||
}
|
|
||||||
} else if (unwrapResultStatus == Status.CLOSED) {
|
|
||||||
throw getHandshakeException(handshakeStatus, "Channel Closed");
|
throw getHandshakeException(handshakeStatus, "Channel Closed");
|
||||||
} else {
|
|
||||||
streamInManager.compact();
|
|
||||||
appDataManager.clear();
|
|
||||||
}
|
}
|
||||||
|
streamInManager.compact();
|
||||||
|
appDataManager.clear();
|
||||||
break;
|
break;
|
||||||
case NEED_WRAP:
|
case NEED_WRAP:
|
||||||
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
||||||
|
@ -536,7 +478,7 @@ public class SSLSocketChannel implements Closeable {
|
||||||
if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
|
if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
|
||||||
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
||||||
} else if (wrapResultStatus == Status.OK) {
|
} else if (wrapResultStatus == Status.OK) {
|
||||||
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
|
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||||
final int bytesRemaining = streamBuffer.remaining();
|
final int bytesRemaining = streamBuffer.remaining();
|
||||||
writeChannel(streamBuffer);
|
writeChannel(streamBuffer);
|
||||||
logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
|
logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
|
||||||
|
@ -549,8 +491,29 @@ public class SSLSocketChannel implements Closeable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int readChannel(final ByteBuffer outputBuffer) throws IOException {
|
private SSLEngineResult unwrapBufferReadChannel() throws IOException {
|
||||||
|
SSLEngineResult unwrapResult = unwrap();
|
||||||
|
|
||||||
|
while (Status.BUFFER_UNDERFLOW == unwrapResult.getStatus()) {
|
||||||
|
final int channelBytesRead = readChannel();
|
||||||
|
if (channelBytesRead == END_OF_STREAM) {
|
||||||
|
throw new EOFException("End of Stream found for Channel Read");
|
||||||
|
}
|
||||||
|
|
||||||
|
unwrapResult = unwrap();
|
||||||
|
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
|
||||||
|
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
|
||||||
|
logOperation("Processing Post-Handshake Messages");
|
||||||
|
unwrapResult = unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return unwrapResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int readChannel() throws IOException {
|
||||||
logOperation("Channel Read Started");
|
logOperation("Channel Read Started");
|
||||||
|
final ByteBuffer outputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
||||||
|
|
||||||
final long started = System.currentTimeMillis();
|
final long started = System.currentTimeMillis();
|
||||||
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
|
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
|
||||||
|
@ -568,10 +531,24 @@ public class SSLSocketChannel implements Closeable {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logOperationBytes("Channel Read Completed", channelBytesRead);
|
||||||
return channelBytesRead;
|
return channelBytesRead;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void readChannelDiscard() {
|
||||||
|
try {
|
||||||
|
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
|
||||||
|
int bytesRead = channel.read(readBuffer);
|
||||||
|
while (bytesRead > 0) {
|
||||||
|
readBuffer.clear();
|
||||||
|
bytesRead = channel.read(readBuffer);
|
||||||
|
}
|
||||||
|
} catch (final IOException e) {
|
||||||
|
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
|
private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
|
||||||
long lastWriteCompleted = System.currentTimeMillis();
|
long lastWriteCompleted = System.currentTimeMillis();
|
||||||
|
|
||||||
|
@ -605,19 +582,6 @@ public class SSLSocketChannel implements Closeable {
|
||||||
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
|
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void readChannelDiscard() {
|
|
||||||
try {
|
|
||||||
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
|
|
||||||
int bytesRead = channel.read(readBuffer);
|
|
||||||
while (bytesRead > 0) {
|
|
||||||
readBuffer.clear();
|
|
||||||
bytesRead = channel.read(readBuffer);
|
|
||||||
}
|
|
||||||
} catch (final IOException e) {
|
|
||||||
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
|
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
|
||||||
logOperationBytes("Application Buffer Read Requested", len);
|
logOperationBytes("Application Buffer Read Requested", len);
|
||||||
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
|
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
|
||||||
|
|
|
@ -58,8 +58,4 @@ public class SSLSocketChannelInputStream extends InputStream {
|
||||||
public int available() throws IOException {
|
public int available() throws IOException {
|
||||||
return channel.available();
|
return channel.available();
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isDataAvailable() throws IOException {
|
|
||||||
return available() > 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,9 +38,10 @@ import org.apache.nifi.security.util.SslContextFactory;
|
||||||
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
|
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
|
||||||
import org.apache.nifi.security.util.TlsConfiguration;
|
import org.apache.nifi.security.util.TlsConfiguration;
|
||||||
import org.apache.nifi.security.util.TlsPlatform;
|
import org.apache.nifi.security.util.TlsPlatform;
|
||||||
import org.junit.Assume;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.BeforeClass;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Timeout;
|
||||||
|
import org.junit.jupiter.api.condition.EnabledIf;
|
||||||
|
|
||||||
import javax.net.ssl.SSLContext;
|
import javax.net.ssl.SSLContext;
|
||||||
import javax.net.ssl.SSLEngine;
|
import javax.net.ssl.SSLEngine;
|
||||||
|
@ -55,17 +56,19 @@ import java.nio.charset.Charset;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.security.GeneralSecurityException;
|
import java.security.GeneralSecurityException;
|
||||||
import java.util.concurrent.BlockingQueue;
|
import java.util.concurrent.BlockingQueue;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.LinkedBlockingQueue;
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
@Timeout(value = 15)
|
||||||
public class SSLSocketChannelTest {
|
public class SSLSocketChannelTest {
|
||||||
private static final String LOCALHOST = "localhost";
|
private static final String LOCALHOST = "localhost";
|
||||||
|
|
||||||
|
@ -81,10 +84,10 @@ public class SSLSocketChannelTest {
|
||||||
|
|
||||||
private static final int CHANNEL_POLL_TIMEOUT = 5000;
|
private static final int CHANNEL_POLL_TIMEOUT = 5000;
|
||||||
|
|
||||||
private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
|
|
||||||
|
|
||||||
private static final int MAX_MESSAGE_LENGTH = 1024;
|
private static final int MAX_MESSAGE_LENGTH = 1024;
|
||||||
|
|
||||||
|
private static final long SHUTDOWN_TIMEOUT = 100;
|
||||||
|
|
||||||
private static final String TLS_1_3 = "TLSv1.3";
|
private static final String TLS_1_3 = "TLSv1.3";
|
||||||
|
|
||||||
private static final String TLS_1_2 = "TLSv1.2";
|
private static final String TLS_1_2 = "TLSv1.2";
|
||||||
|
@ -97,9 +100,17 @@ public class SSLSocketChannelTest {
|
||||||
|
|
||||||
private static final int FIRST_BYTE_OFFSET = 1;
|
private static final int FIRST_BYTE_OFFSET = 1;
|
||||||
|
|
||||||
|
private static final int SINGLE_COUNT_DOWN = 1;
|
||||||
|
|
||||||
private static SSLContext sslContext;
|
private static SSLContext sslContext;
|
||||||
|
|
||||||
@BeforeClass
|
private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
|
||||||
|
|
||||||
|
public static boolean isTls13Supported() {
|
||||||
|
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
public static void setConfiguration() throws GeneralSecurityException {
|
public static void setConfiguration() throws GeneralSecurityException {
|
||||||
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
|
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
|
||||||
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
|
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
|
||||||
|
@ -115,54 +126,60 @@ public class SSLSocketChannelTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testClientConnectHandshakeFailed() throws IOException {
|
public void testClientConnectHandshakeFailed() throws IOException {
|
||||||
assumeProtocolSupported(TLS_1_2);
|
final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
|
||||||
|
|
||||||
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
||||||
|
|
||||||
try (final SocketChannel socketChannel = SocketChannel.open()) {
|
try (final SocketChannel socketChannel = SocketChannel.open()) {
|
||||||
final int port = NetworkUtils.getAvailableTcpPort();
|
final int port = NetworkUtils.getAvailableTcpPort();
|
||||||
startServer(group, port, TLS_1_2);
|
startServer(group, port, enabledProtocol, getSingleCountDownLatch());
|
||||||
|
|
||||||
socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
|
socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
|
||||||
final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
|
final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
|
||||||
|
|
||||||
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
|
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
|
||||||
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
|
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
|
||||||
|
|
||||||
group.shutdownGracefully().syncUninterruptibly();
|
shutdownGroup(group);
|
||||||
assertThrows(SSLException.class, sslSocketChannel::connect);
|
assertThrows(SSLException.class, sslSocketChannel::connect);
|
||||||
} finally {
|
} finally {
|
||||||
group.shutdownGracefully().syncUninterruptibly();
|
shutdownGroup(group);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testClientConnectWriteReadTls12() throws Exception {
|
public void testClientConnectWriteReadTls12() throws Exception {
|
||||||
assumeProtocolSupported(TLS_1_2);
|
|
||||||
assertChannelConnectedWriteReadClosed(TLS_1_2);
|
assertChannelConnectedWriteReadClosed(TLS_1_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||||
@Test
|
@Test
|
||||||
public void testClientConnectWriteReadTls13() throws Exception {
|
public void testClientConnectWriteReadTls13() throws Exception {
|
||||||
assumeProtocolSupported(TLS_1_3);
|
|
||||||
assertChannelConnectedWriteReadClosed(TLS_1_3);
|
assertChannelConnectedWriteReadClosed(TLS_1_3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = CHANNEL_TIMEOUT)
|
@Test
|
||||||
|
public void testClientConnectWriteAvailableReadTls12() throws Exception {
|
||||||
|
assertChannelConnectedWriteAvailableRead(TLS_1_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||||
|
@Test
|
||||||
|
public void testClientConnectWriteAvailableReadTls13() throws Exception {
|
||||||
|
assertChannelConnectedWriteAvailableRead(TLS_1_3);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testServerReadWriteTls12() throws Exception {
|
public void testServerReadWriteTls12() throws Exception {
|
||||||
assumeProtocolSupported(TLS_1_2);
|
|
||||||
assertServerChannelConnectedReadClosed(TLS_1_2);
|
assertServerChannelConnectedReadClosed(TLS_1_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = CHANNEL_TIMEOUT)
|
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||||
|
@Test
|
||||||
public void testServerReadWriteTls13() throws Exception {
|
public void testServerReadWriteTls13() throws Exception {
|
||||||
assumeProtocolSupported(TLS_1_3);
|
|
||||||
assertServerChannelConnectedReadClosed(TLS_1_3);
|
assertServerChannelConnectedReadClosed(TLS_1_3);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assumeProtocolSupported(final String protocol) {
|
|
||||||
Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
|
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
|
||||||
final int port = NetworkUtils.getAvailableTcpPort();
|
final int port = NetworkUtils.getAvailableTcpPort();
|
||||||
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
|
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
|
||||||
|
@ -194,67 +211,100 @@ public class SSLSocketChannelTest {
|
||||||
channel.writeAndFlush(MESSAGE).syncUninterruptibly();
|
channel.writeAndFlush(MESSAGE).syncUninterruptibly();
|
||||||
|
|
||||||
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
|
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
|
||||||
assertEquals("Message not matched", MESSAGE, messageRead);
|
assertEquals(MESSAGE, messageRead, "Message not matched");
|
||||||
} finally {
|
} finally {
|
||||||
channel.close();
|
channel.close();
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
group.shutdownGracefully().syncUninterruptibly();
|
shutdownGroup(group);
|
||||||
serverSocketChannel.close();
|
serverSocketChannel.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
|
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
|
||||||
processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
|
final CountDownLatch countDownLatch = getSingleCountDownLatch();
|
||||||
|
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
|
||||||
try {
|
try {
|
||||||
sslSocketChannel.connect();
|
sslSocketChannel.connect();
|
||||||
assertFalse("Channel closed", sslSocketChannel.isClosed());
|
assertFalse(sslSocketChannel.isClosed());
|
||||||
|
|
||||||
assertChannelWriteRead(sslSocketChannel);
|
assertChannelWriteRead(sslSocketChannel, countDownLatch);
|
||||||
|
|
||||||
sslSocketChannel.close();
|
sslSocketChannel.close();
|
||||||
assertTrue("Channel not closed", sslSocketChannel.isClosed());
|
assertTrue(sslSocketChannel.isClosed());
|
||||||
} catch (final IOException e) {
|
} catch (final IOException e) {
|
||||||
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
|
private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
|
||||||
sslSocketChannel.write(MESSAGE_BYTES);
|
final CountDownLatch countDownLatch = getSingleCountDownLatch();
|
||||||
|
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
|
||||||
while (sslSocketChannel.available() == 0) {
|
|
||||||
try {
|
try {
|
||||||
TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
|
sslSocketChannel.connect();
|
||||||
} catch (final InterruptedException e) {
|
assertFalse(sslSocketChannel.isClosed());
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);
|
||||||
|
|
||||||
|
sslSocketChannel.close();
|
||||||
|
assertTrue(sslSocketChannel.isClosed());
|
||||||
|
} catch (final IOException e) {
|
||||||
|
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
|
||||||
|
sslSocketChannel.write(MESSAGE_BYTES);
|
||||||
|
sslSocketChannel.available();
|
||||||
|
awaitCountDownLatch(countDownLatch);
|
||||||
|
assetMessageRead(sslSocketChannel);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
|
||||||
|
sslSocketChannel.write(MESSAGE_BYTES);
|
||||||
|
awaitCountDownLatch(countDownLatch);
|
||||||
|
assetMessageRead(sslSocketChannel);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
|
||||||
|
try {
|
||||||
|
countDownLatch.await();
|
||||||
|
} catch (final InterruptedException e) {
|
||||||
|
throw new IOException("Count Down Interrupted", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assetMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
|
||||||
final byte firstByteRead = (byte) sslSocketChannel.read();
|
final byte firstByteRead = (byte) sslSocketChannel.read();
|
||||||
assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
|
assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");
|
||||||
|
|
||||||
|
final int available = sslSocketChannel.available();
|
||||||
|
final int availableExpected = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
|
||||||
|
assertEquals(availableExpected, available, "Available Bytes not matched");
|
||||||
|
|
||||||
final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
|
final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
|
||||||
messageBytes[0] = firstByteRead;
|
messageBytes[0] = firstByteRead;
|
||||||
|
|
||||||
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
|
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
|
||||||
assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
|
assertEquals(messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead, "Channel Message Bytes Read not matched");
|
||||||
|
|
||||||
final String message = new String(messageBytes, MESSAGE_CHARSET);
|
final String message = new String(messageBytes, MESSAGE_CHARSET);
|
||||||
assertEquals("Channel Message not matched", MESSAGE, message);
|
assertEquals(MESSAGE, message, "Message not matched");
|
||||||
}
|
}
|
||||||
|
|
||||||
private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
|
private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
|
||||||
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
final int port = NetworkUtils.getAvailableTcpPort();
|
final int port = NetworkUtils.getAvailableTcpPort();
|
||||||
startServer(group, port, enabledProtocol);
|
startServer(group, port, enabledProtocol, countDownLatch);
|
||||||
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
|
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
|
||||||
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
|
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
|
||||||
channelConsumer.accept(sslSocketChannel);
|
channelConsumer.accept(sslSocketChannel);
|
||||||
} finally {
|
} finally {
|
||||||
group.shutdownGracefully().syncUninterruptibly();
|
shutdownGroup(group);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,7 +323,7 @@ public class SSLSocketChannelTest {
|
||||||
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
|
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
|
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol, final CountDownLatch countDownLatch) {
|
||||||
final ServerBootstrap bootstrap = new ServerBootstrap();
|
final ServerBootstrap bootstrap = new ServerBootstrap();
|
||||||
bootstrap.group(group);
|
bootstrap.group(group);
|
||||||
bootstrap.channel(NioServerSocketChannel.class);
|
bootstrap.channel(NioServerSocketChannel.class);
|
||||||
|
@ -287,6 +337,7 @@ public class SSLSocketChannelTest {
|
||||||
@Override
|
@Override
|
||||||
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
|
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
|
||||||
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
|
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
|
||||||
|
countDownLatch.countDown();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -309,4 +360,12 @@ public class SSLSocketChannelTest {
|
||||||
pipeline.addLast(new StringDecoder());
|
pipeline.addLast(new StringDecoder());
|
||||||
pipeline.addLast(new StringEncoder());
|
pipeline.addLast(new StringEncoder());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void shutdownGroup(final EventLoopGroup group) {
|
||||||
|
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
|
||||||
|
}
|
||||||
|
|
||||||
|
private CountDownLatch getSingleCountDownLatch() {
|
||||||
|
return new CountDownLatch(SINGLE_COUNT_DOWN);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue