mirror of https://github.com/apache/nifi.git
NIFI-9253 Corrected SSLSocketChannel.available() for TLSv1.3
- Added unit tests to reproduce issues with available() method - Changed available() to return size of application buffer - Removed unused isDataAvailable() - Refactored unwrap handling to read from channel for buffer underflow Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com> This closes #5421.
This commit is contained in:
parent
ae0154de5a
commit
defea61075
|
@ -30,6 +30,7 @@ import javax.net.ssl.SSLException;
|
|||
import javax.net.ssl.SSLHandshakeException;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import java.io.Closeable;
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
|
@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
|
|||
public class SSLSocketChannel implements Closeable {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
|
||||
|
||||
private static final int MINIMUM_READ_BUFFER_SIZE = 1;
|
||||
private static final int DISCARD_BUFFER_LENGTH = 8192;
|
||||
private static final int END_OF_STREAM = -1;
|
||||
private static final byte[] EMPTY_MESSAGE = new byte[0];
|
||||
|
@ -266,7 +268,7 @@ public class SSLSocketChannel implements Closeable {
|
|||
status = wrapResult.getStatus();
|
||||
}
|
||||
if (Status.CLOSED == status) {
|
||||
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
|
||||
final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||
try {
|
||||
writeChannel(streamOutputBuffer);
|
||||
} catch (final IOException e) {
|
||||
|
@ -291,39 +293,8 @@ public class SSLSocketChannel implements Closeable {
|
|||
* @throws IOException Thrown on failures checking for available bytes
|
||||
*/
|
||||
public int available() throws IOException {
|
||||
ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
|
||||
ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
|
||||
final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
|
||||
if (buffered > 0) {
|
||||
return buffered;
|
||||
}
|
||||
|
||||
if (!isDataAvailable()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
appDataBuffer = appDataManager.prepareForRead(1);
|
||||
streamDataBuffer = streamInManager.prepareForRead(1);
|
||||
return appDataBuffer.remaining() + streamDataBuffer.remaining();
|
||||
}
|
||||
|
||||
/**
|
||||
* Is data available for reading
|
||||
*
|
||||
* @return Data available status
|
||||
* @throws IOException Thrown on SocketChannel.read() failures
|
||||
*/
|
||||
public boolean isDataAvailable() throws IOException {
|
||||
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
|
||||
final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
|
||||
|
||||
if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
||||
final int bytesRead = channel.read(writableBuffer);
|
||||
return (bytesRead > 0);
|
||||
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||
return appDataBuffer.remaining();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -373,42 +344,24 @@ public class SSLSocketChannel implements Closeable {
|
|||
}
|
||||
appDataManager.clear();
|
||||
|
||||
while (true) {
|
||||
final SSLEngineResult unwrapResult = unwrap();
|
||||
|
||||
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
|
||||
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
|
||||
logOperation("Processing Post-Handshake Messages");
|
||||
continue;
|
||||
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
|
||||
final Status status = unwrapResult.getStatus();
|
||||
if (Status.CLOSED == status) {
|
||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||
if (applicationBytesRead == 0) {
|
||||
return END_OF_STREAM;
|
||||
}
|
||||
|
||||
final Status status = unwrapResult.getStatus();
|
||||
switch (status) {
|
||||
case BUFFER_OVERFLOW:
|
||||
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
|
||||
case BUFFER_UNDERFLOW:
|
||||
final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
||||
final int channelBytesRead = readChannel(streamBuffer);
|
||||
logOperationBytes("Channel Read Completed", channelBytesRead);
|
||||
if (channelBytesRead == END_OF_STREAM) {
|
||||
return END_OF_STREAM;
|
||||
}
|
||||
break;
|
||||
case CLOSED:
|
||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||
if (applicationBytesRead == 0) {
|
||||
return END_OF_STREAM;
|
||||
}
|
||||
streamInManager.compact();
|
||||
return applicationBytesRead;
|
||||
case OK:
|
||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||
if (applicationBytesRead == 0) {
|
||||
throw new IOException("Read Application Buffer Failed");
|
||||
}
|
||||
streamInManager.compact();
|
||||
return applicationBytesRead;
|
||||
streamInManager.compact();
|
||||
return applicationBytesRead;
|
||||
} else if (Status.OK == status) {
|
||||
applicationBytesRead = readApplicationBuffer(buffer, offset, len);
|
||||
if (applicationBytesRead == 0) {
|
||||
throw new IOException("Read Application Buffer Failed");
|
||||
}
|
||||
streamInManager.compact();
|
||||
return applicationBytesRead;
|
||||
} else {
|
||||
throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not expected from unwrap", status));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -508,24 +461,13 @@ public class SSLSocketChannel implements Closeable {
|
|||
handshakeStatus = engine.getHandshakeStatus();
|
||||
break;
|
||||
case NEED_UNWRAP:
|
||||
final SSLEngineResult unwrapResult = unwrap();
|
||||
final SSLEngineResult unwrapResult = unwrapBufferReadChannel();
|
||||
handshakeStatus = unwrapResult.getHandshakeStatus();
|
||||
Status unwrapResultStatus = unwrapResult.getStatus();
|
||||
|
||||
if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
|
||||
final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
||||
final int bytesRead = readChannel(writableDataIn);
|
||||
logOperationBytes("Handshake Channel Read", bytesRead);
|
||||
|
||||
if (bytesRead == END_OF_STREAM) {
|
||||
throw getHandshakeException(handshakeStatus, "End of Stream Found");
|
||||
}
|
||||
} else if (unwrapResultStatus == Status.CLOSED) {
|
||||
if (unwrapResult.getStatus() == Status.CLOSED) {
|
||||
throw getHandshakeException(handshakeStatus, "Channel Closed");
|
||||
} else {
|
||||
streamInManager.compact();
|
||||
appDataManager.clear();
|
||||
}
|
||||
streamInManager.compact();
|
||||
appDataManager.clear();
|
||||
break;
|
||||
case NEED_WRAP:
|
||||
final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
||||
|
@ -536,7 +478,7 @@ public class SSLSocketChannel implements Closeable {
|
|||
if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
|
||||
streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
|
||||
} else if (wrapResultStatus == Status.OK) {
|
||||
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
|
||||
final ByteBuffer streamBuffer = streamOutManager.prepareForRead(MINIMUM_READ_BUFFER_SIZE);
|
||||
final int bytesRemaining = streamBuffer.remaining();
|
||||
writeChannel(streamBuffer);
|
||||
logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
|
||||
|
@ -549,8 +491,29 @@ public class SSLSocketChannel implements Closeable {
|
|||
}
|
||||
}
|
||||
|
||||
private int readChannel(final ByteBuffer outputBuffer) throws IOException {
|
||||
private SSLEngineResult unwrapBufferReadChannel() throws IOException {
|
||||
SSLEngineResult unwrapResult = unwrap();
|
||||
|
||||
while (Status.BUFFER_UNDERFLOW == unwrapResult.getStatus()) {
|
||||
final int channelBytesRead = readChannel();
|
||||
if (channelBytesRead == END_OF_STREAM) {
|
||||
throw new EOFException("End of Stream found for Channel Read");
|
||||
}
|
||||
|
||||
unwrapResult = unwrap();
|
||||
if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
|
||||
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
|
||||
logOperation("Processing Post-Handshake Messages");
|
||||
unwrapResult = unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
return unwrapResult;
|
||||
}
|
||||
|
||||
private int readChannel() throws IOException {
|
||||
logOperation("Channel Read Started");
|
||||
final ByteBuffer outputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
|
||||
|
||||
final long started = System.currentTimeMillis();
|
||||
long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
|
||||
|
@ -568,10 +531,24 @@ public class SSLSocketChannel implements Closeable {
|
|||
continue;
|
||||
}
|
||||
|
||||
logOperationBytes("Channel Read Completed", channelBytesRead);
|
||||
return channelBytesRead;
|
||||
}
|
||||
}
|
||||
|
||||
private void readChannelDiscard() {
|
||||
try {
|
||||
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
|
||||
int bytesRead = channel.read(readBuffer);
|
||||
while (bytesRead > 0) {
|
||||
readBuffer.clear();
|
||||
bytesRead = channel.read(readBuffer);
|
||||
}
|
||||
} catch (final IOException e) {
|
||||
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
|
||||
long lastWriteCompleted = System.currentTimeMillis();
|
||||
|
||||
|
@ -605,19 +582,6 @@ public class SSLSocketChannel implements Closeable {
|
|||
return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
|
||||
}
|
||||
|
||||
private void readChannelDiscard() {
|
||||
try {
|
||||
final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
|
||||
int bytesRead = channel.read(readBuffer);
|
||||
while (bytesRead > 0) {
|
||||
readBuffer.clear();
|
||||
bytesRead = channel.read(readBuffer);
|
||||
}
|
||||
} catch (final IOException e) {
|
||||
LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
|
||||
}
|
||||
}
|
||||
|
||||
private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
|
||||
logOperationBytes("Application Buffer Read Requested", len);
|
||||
final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
|
||||
|
|
|
@ -58,8 +58,4 @@ public class SSLSocketChannelInputStream extends InputStream {
|
|||
public int available() throws IOException {
|
||||
return channel.available();
|
||||
}
|
||||
|
||||
public boolean isDataAvailable() throws IOException {
|
||||
return available() > 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,9 +38,10 @@ import org.apache.nifi.security.util.SslContextFactory;
|
|||
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
|
||||
import org.apache.nifi.security.util.TlsConfiguration;
|
||||
import org.apache.nifi.security.util.TlsPlatform;
|
||||
import org.junit.Assume;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.condition.EnabledIf;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
@ -55,17 +56,19 @@ import java.nio.charset.Charset;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Timeout(value = 15)
|
||||
public class SSLSocketChannelTest {
|
||||
private static final String LOCALHOST = "localhost";
|
||||
|
||||
|
@ -81,10 +84,10 @@ public class SSLSocketChannelTest {
|
|||
|
||||
private static final int CHANNEL_POLL_TIMEOUT = 5000;
|
||||
|
||||
private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
|
||||
|
||||
private static final int MAX_MESSAGE_LENGTH = 1024;
|
||||
|
||||
private static final long SHUTDOWN_TIMEOUT = 100;
|
||||
|
||||
private static final String TLS_1_3 = "TLSv1.3";
|
||||
|
||||
private static final String TLS_1_2 = "TLSv1.2";
|
||||
|
@ -97,9 +100,17 @@ public class SSLSocketChannelTest {
|
|||
|
||||
private static final int FIRST_BYTE_OFFSET = 1;
|
||||
|
||||
private static final int SINGLE_COUNT_DOWN = 1;
|
||||
|
||||
private static SSLContext sslContext;
|
||||
|
||||
@BeforeClass
|
||||
private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
|
||||
|
||||
public static boolean isTls13Supported() {
|
||||
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
|
||||
}
|
||||
|
||||
@BeforeAll
|
||||
public static void setConfiguration() throws GeneralSecurityException {
|
||||
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
|
||||
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
|
||||
|
@ -115,54 +126,60 @@ public class SSLSocketChannelTest {
|
|||
|
||||
@Test
|
||||
public void testClientConnectHandshakeFailed() throws IOException {
|
||||
assumeProtocolSupported(TLS_1_2);
|
||||
final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
|
||||
|
||||
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
||||
|
||||
try (final SocketChannel socketChannel = SocketChannel.open()) {
|
||||
final int port = NetworkUtils.getAvailableTcpPort();
|
||||
startServer(group, port, TLS_1_2);
|
||||
startServer(group, port, enabledProtocol, getSingleCountDownLatch());
|
||||
|
||||
socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
|
||||
final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
|
||||
final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
|
||||
|
||||
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
|
||||
sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
|
||||
|
||||
group.shutdownGracefully().syncUninterruptibly();
|
||||
shutdownGroup(group);
|
||||
assertThrows(SSLException.class, sslSocketChannel::connect);
|
||||
} finally {
|
||||
group.shutdownGracefully().syncUninterruptibly();
|
||||
shutdownGroup(group);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testClientConnectWriteReadTls12() throws Exception {
|
||||
assumeProtocolSupported(TLS_1_2);
|
||||
assertChannelConnectedWriteReadClosed(TLS_1_2);
|
||||
}
|
||||
|
||||
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||
@Test
|
||||
public void testClientConnectWriteReadTls13() throws Exception {
|
||||
assumeProtocolSupported(TLS_1_3);
|
||||
assertChannelConnectedWriteReadClosed(TLS_1_3);
|
||||
}
|
||||
|
||||
@Test(timeout = CHANNEL_TIMEOUT)
|
||||
@Test
|
||||
public void testClientConnectWriteAvailableReadTls12() throws Exception {
|
||||
assertChannelConnectedWriteAvailableRead(TLS_1_2);
|
||||
}
|
||||
|
||||
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||
@Test
|
||||
public void testClientConnectWriteAvailableReadTls13() throws Exception {
|
||||
assertChannelConnectedWriteAvailableRead(TLS_1_3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testServerReadWriteTls12() throws Exception {
|
||||
assumeProtocolSupported(TLS_1_2);
|
||||
assertServerChannelConnectedReadClosed(TLS_1_2);
|
||||
}
|
||||
|
||||
@Test(timeout = CHANNEL_TIMEOUT)
|
||||
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||
@Test
|
||||
public void testServerReadWriteTls13() throws Exception {
|
||||
assumeProtocolSupported(TLS_1_3);
|
||||
assertServerChannelConnectedReadClosed(TLS_1_3);
|
||||
}
|
||||
|
||||
private void assumeProtocolSupported(final String protocol) {
|
||||
Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
|
||||
}
|
||||
|
||||
private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
|
||||
final int port = NetworkUtils.getAvailableTcpPort();
|
||||
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
|
||||
|
@ -194,67 +211,100 @@ public class SSLSocketChannelTest {
|
|||
channel.writeAndFlush(MESSAGE).syncUninterruptibly();
|
||||
|
||||
final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
|
||||
assertEquals("Message not matched", MESSAGE, messageRead);
|
||||
assertEquals(MESSAGE, messageRead, "Message not matched");
|
||||
} finally {
|
||||
channel.close();
|
||||
}
|
||||
} finally {
|
||||
group.shutdownGracefully().syncUninterruptibly();
|
||||
shutdownGroup(group);
|
||||
serverSocketChannel.close();
|
||||
}
|
||||
}
|
||||
|
||||
private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
|
||||
processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
|
||||
final CountDownLatch countDownLatch = getSingleCountDownLatch();
|
||||
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
|
||||
try {
|
||||
sslSocketChannel.connect();
|
||||
assertFalse("Channel closed", sslSocketChannel.isClosed());
|
||||
assertFalse(sslSocketChannel.isClosed());
|
||||
|
||||
assertChannelWriteRead(sslSocketChannel);
|
||||
assertChannelWriteRead(sslSocketChannel, countDownLatch);
|
||||
|
||||
sslSocketChannel.close();
|
||||
assertTrue("Channel not closed", sslSocketChannel.isClosed());
|
||||
assertTrue(sslSocketChannel.isClosed());
|
||||
} catch (final IOException e) {
|
||||
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
|
||||
sslSocketChannel.write(MESSAGE_BYTES);
|
||||
|
||||
while (sslSocketChannel.available() == 0) {
|
||||
private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
|
||||
final CountDownLatch countDownLatch = getSingleCountDownLatch();
|
||||
processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
|
||||
try {
|
||||
TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
|
||||
} catch (final InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
sslSocketChannel.connect();
|
||||
assertFalse(sslSocketChannel.isClosed());
|
||||
|
||||
assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);
|
||||
|
||||
sslSocketChannel.close();
|
||||
assertTrue(sslSocketChannel.isClosed());
|
||||
} catch (final IOException e) {
|
||||
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
|
||||
sslSocketChannel.write(MESSAGE_BYTES);
|
||||
sslSocketChannel.available();
|
||||
awaitCountDownLatch(countDownLatch);
|
||||
assetMessageRead(sslSocketChannel);
|
||||
}
|
||||
|
||||
private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
|
||||
sslSocketChannel.write(MESSAGE_BYTES);
|
||||
awaitCountDownLatch(countDownLatch);
|
||||
assetMessageRead(sslSocketChannel);
|
||||
}
|
||||
|
||||
private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
|
||||
try {
|
||||
countDownLatch.await();
|
||||
} catch (final InterruptedException e) {
|
||||
throw new IOException("Count Down Interrupted", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void assetMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
|
||||
final byte firstByteRead = (byte) sslSocketChannel.read();
|
||||
assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
|
||||
assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");
|
||||
|
||||
final int available = sslSocketChannel.available();
|
||||
final int availableExpected = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
|
||||
assertEquals(availableExpected, available, "Available Bytes not matched");
|
||||
|
||||
final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
|
||||
messageBytes[0] = firstByteRead;
|
||||
|
||||
final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length);
|
||||
assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
|
||||
assertEquals(messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead, "Channel Message Bytes Read not matched");
|
||||
|
||||
final String message = new String(messageBytes, MESSAGE_CHARSET);
|
||||
assertEquals("Channel Message not matched", MESSAGE, message);
|
||||
assertEquals(MESSAGE, message, "Message not matched");
|
||||
}
|
||||
|
||||
private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
|
||||
private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
|
||||
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
||||
|
||||
try {
|
||||
final int port = NetworkUtils.getAvailableTcpPort();
|
||||
startServer(group, port, enabledProtocol);
|
||||
startServer(group, port, enabledProtocol, countDownLatch);
|
||||
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
|
||||
sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
|
||||
channelConsumer.accept(sslSocketChannel);
|
||||
} finally {
|
||||
group.shutdownGracefully().syncUninterruptibly();
|
||||
shutdownGroup(group);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -273,7 +323,7 @@ public class SSLSocketChannelTest {
|
|||
return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
|
||||
}
|
||||
|
||||
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
|
||||
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol, final CountDownLatch countDownLatch) {
|
||||
final ServerBootstrap bootstrap = new ServerBootstrap();
|
||||
bootstrap.group(group);
|
||||
bootstrap.channel(NioServerSocketChannel.class);
|
||||
|
@ -287,6 +337,7 @@ public class SSLSocketChannelTest {
|
|||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
|
||||
channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
|
||||
countDownLatch.countDown();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -309,4 +360,12 @@ public class SSLSocketChannelTest {
|
|||
pipeline.addLast(new StringDecoder());
|
||||
pipeline.addLast(new StringEncoder());
|
||||
}
|
||||
|
||||
private void shutdownGroup(final EventLoopGroup group) {
|
||||
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
|
||||
}
|
||||
|
||||
private CountDownLatch getSingleCountDownLatch() {
|
||||
return new CountDownLatch(SINGLE_COUNT_DOWN);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue