mirror of https://github.com/apache/nifi.git
NIFI-9761 Correct PeerChannel processing for TLS 1.3 (#5836)
* NIFI-9761 Corrected PeerChannel processing for TLS 1.3 - Added TestPeerChannel with methods for TLS 1.2 and TLS 1.3 - Updated PeerChannel.close() to process SSLEngine close notification - Improved logging and corrected handling after decryption
This commit is contained in:
parent
6a1c7c72d5
commit
c73573b325
|
@ -264,6 +264,11 @@
|
|||
<version>${nifi.groovy.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.netty</groupId>
|
||||
<artifactId>netty-handler</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.ben-manes.caffeine</groupId>
|
||||
<artifactId>caffeine</artifactId>
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory;
|
|||
import javax.net.ssl.SSLEngine;
|
||||
import javax.net.ssl.SSLEngineResult;
|
||||
import javax.net.ssl.SSLException;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -30,6 +31,10 @@ import java.nio.channels.SocketChannel;
|
|||
import java.util.OptionalInt;
|
||||
|
||||
public class PeerChannel implements Closeable {
|
||||
private static final int END_OF_FILE = -1;
|
||||
|
||||
private static final int EMPTY_BUFFER = 0;
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(PeerChannel.class);
|
||||
|
||||
private final SocketChannel socketChannel;
|
||||
|
@ -38,7 +43,7 @@ public class PeerChannel implements Closeable {
|
|||
|
||||
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
|
||||
private ByteBuffer destinationBuffer = ByteBuffer.allocate(16 * 1024); // buffer that SSLEngine is to write into
|
||||
private ByteBuffer streamBuffer = ByteBuffer.allocate(16 * 1024); // buffer for data that is read from SocketChannel
|
||||
private final ByteBuffer streamBuffer = ByteBuffer.allocate(16 * 1024); // buffer for data that is read from SocketChannel
|
||||
private ByteBuffer applicationBuffer = ByteBuffer.allocate(0); // buffer for application-level data that is ready to be served up (i.e., already decrypted if necessary)
|
||||
|
||||
public PeerChannel(final SocketChannel socketChannel, final SSLEngine sslEngine, final String peerDescription) {
|
||||
|
@ -47,10 +52,45 @@ public class PeerChannel implements Closeable {
|
|||
this.peerDescription = peerDescription;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Close Socket Channel and process SSLEngine close notifications when configured
|
||||
*
|
||||
* @throws IOException Thrown on failure to close Socket Channel or process SSLEngine operations
|
||||
*/
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
socketChannel.close();
|
||||
try {
|
||||
if (sslEngine == null) {
|
||||
logger.debug("Closing Peer Channel [{}] SSLEngine not configured", peerDescription);
|
||||
} else {
|
||||
logger.debug("Closing Peer Channel [{}] SSLEngine close started", peerDescription);
|
||||
sslEngine.closeOutbound();
|
||||
|
||||
// Send TLS close notification packets available after initiating SSLEngine.closeOutbound()
|
||||
final ByteBuffer inputBuffer = ByteBuffer.allocate(0);
|
||||
final ByteBuffer outputBuffer = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
|
||||
|
||||
SSLEngineResult wrapResult = sslEngine.wrap(inputBuffer, outputBuffer);
|
||||
SSLEngineResult.Status status = wrapResult.getStatus();
|
||||
outputBuffer.flip();
|
||||
if (SSLEngineResult.Status.OK == status) {
|
||||
write(outputBuffer);
|
||||
outputBuffer.clear();
|
||||
wrapResult = sslEngine.wrap(inputBuffer, outputBuffer);
|
||||
status = wrapResult.getStatus();
|
||||
}
|
||||
if (SSLEngineResult.Status.CLOSED == status) {
|
||||
write(outputBuffer);
|
||||
} else {
|
||||
throw new SSLException(String.format("Closing Peer Channel [%s] Invalid Wrap Result Status [%s]", peerDescription, status));
|
||||
}
|
||||
|
||||
logger.debug("Closing Peer Channel [{}] SSLEngine close completed", peerDescription);
|
||||
}
|
||||
} finally {
|
||||
logger.debug("Closing Peer Channel [{}] Socket Channel close started", peerDescription);
|
||||
socketChannel.close();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isConnected() {
|
||||
|
@ -65,6 +105,13 @@ public class PeerChannel implements Closeable {
|
|||
return peerDescription;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write one byte to the channel
|
||||
*
|
||||
* @param b Byte to be written
|
||||
* @return Status of write operation returns true on success
|
||||
* @throws IOException Thrown on failure to write to the Socket Channel
|
||||
*/
|
||||
public boolean write(final byte b) throws IOException {
|
||||
singleByteBuffer.clear();
|
||||
singleByteBuffer.put(b);
|
||||
|
@ -75,13 +122,18 @@ public class PeerChannel implements Closeable {
|
|||
return bytesWritten > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read one byte as an unsigned integer from the channel
|
||||
*
|
||||
* @return Returns empty when zero bytes are available and returns negative one when the channel is closed
|
||||
* @throws IOException Thrown on failure to read from Socket Channel
|
||||
*/
|
||||
public OptionalInt read() throws IOException {
|
||||
singleByteBuffer.clear();
|
||||
final int bytesRead = read(singleByteBuffer);
|
||||
if (bytesRead < 0) {
|
||||
return OptionalInt.of(-1);
|
||||
}
|
||||
if (bytesRead == 0) {
|
||||
return OptionalInt.of(END_OF_FILE);
|
||||
} else if (bytesRead == EMPTY_BUFFER) {
|
||||
return OptionalInt.empty();
|
||||
}
|
||||
|
||||
|
@ -91,9 +143,6 @@ public class PeerChannel implements Closeable {
|
|||
return OptionalInt.of(read & 0xFF);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Reads the given ByteBuffer of data and returns a new ByteBuffer (which is "flipped" / ready to be read). The newly returned
|
||||
* ByteBuffer will be written to be written via the {@link #write(ByteBuffer)} method. I.e., it will have already been encrypted, if
|
||||
|
@ -104,11 +153,11 @@ public class PeerChannel implements Closeable {
|
|||
* @throws IOException if a failure occurs while encrypting the data
|
||||
*/
|
||||
public ByteBuffer prepareForWrite(final ByteBuffer plaintext) throws IOException {
|
||||
logger.trace("Channel [{}] Buffer wrap started: Input Bytes [{}]", peerDescription, plaintext.remaining());
|
||||
if (sslEngine == null) {
|
||||
return plaintext;
|
||||
}
|
||||
|
||||
|
||||
ByteBuffer prepared = ByteBuffer.allocate(Math.min(85, plaintext.capacity() - plaintext.position()));
|
||||
while (plaintext.hasRemaining()) {
|
||||
encrypt(plaintext);
|
||||
|
@ -125,14 +174,28 @@ public class PeerChannel implements Closeable {
|
|||
}
|
||||
|
||||
prepared.flip();
|
||||
logger.trace("Channel [{}] Buffer wrap completed: Prepared Bytes [{}]", peerDescription, prepared.remaining());
|
||||
return prepared;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write prepared buffer to Socket Channel
|
||||
*
|
||||
* @param preparedBuffer Buffer must contain bytes processed through prepareForWrite() when TLS is enabled
|
||||
* @return Number of bytes written according to SocketChannel.write()
|
||||
* @throws IOException Thrown on failure to write to the Socket Channel
|
||||
*/
|
||||
public int write(final ByteBuffer preparedBuffer) throws IOException {
|
||||
return socketChannel.write(preparedBuffer);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Read application data bytes into the provided buffer
|
||||
*
|
||||
* @param dst Buffer to be populated with application data bytes
|
||||
* @return Number of bytes read into the provided buffer
|
||||
* @throws IOException Thrown on failure to read from the Socket Channel
|
||||
*/
|
||||
public int read(final ByteBuffer dst) throws IOException {
|
||||
// If we have data ready to go, then go ahead and copy it.
|
||||
final int bytesCopied = copy(applicationBuffer, dst);
|
||||
|
@ -141,12 +204,11 @@ public class PeerChannel implements Closeable {
|
|||
}
|
||||
|
||||
final int bytesRead = socketChannel.read(streamBuffer);
|
||||
if (bytesRead < 1) {
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
if (bytesRead > 0) {
|
||||
logger.trace("Read {} bytes from SocketChannel", bytesRead);
|
||||
logger.trace("Channel [{}] Socket read completed: bytes [{}]", peerDescription, bytesRead);
|
||||
if (bytesRead == END_OF_FILE) {
|
||||
return END_OF_FILE;
|
||||
} else if (streamBuffer.remaining() == EMPTY_BUFFER) {
|
||||
return EMPTY_BUFFER;
|
||||
}
|
||||
|
||||
streamBuffer.flip();
|
||||
|
@ -157,7 +219,7 @@ public class PeerChannel implements Closeable {
|
|||
return copy(applicationBuffer, dst);
|
||||
} else {
|
||||
final boolean decrypted = decrypt(streamBuffer);
|
||||
logger.trace("Decryption after reading those bytes successful = {}", decrypted);
|
||||
logger.trace("Channel [{}] Decryption completed [{}]", peerDescription, decrypted);
|
||||
|
||||
if (decrypted) {
|
||||
cloneToApplicationBuffer(destinationBuffer);
|
||||
|
@ -167,9 +229,8 @@ public class PeerChannel implements Closeable {
|
|||
} else {
|
||||
// Not enough data to decrypt. Compact the buffer so that we keep the data we have
|
||||
// but prepare the buffer to be written to again.
|
||||
logger.debug("Not enough data to decrypt. Will need to consume more data before decrypting");
|
||||
streamBuffer.compact();
|
||||
return 0;
|
||||
logger.trace("Channel [{}] Socket Channel read required", peerDescription);
|
||||
return EMPTY_BUFFER;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
|
@ -219,6 +280,7 @@ public class PeerChannel implements Closeable {
|
|||
|
||||
while (true) {
|
||||
final SSLEngineResult result = sslEngine.wrap(plaintext, destinationBuffer);
|
||||
logOperationResult("WRAP", result);
|
||||
|
||||
switch (result.getStatus()) {
|
||||
case OK:
|
||||
|
@ -240,9 +302,6 @@ public class PeerChannel implements Closeable {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Attempts to decrypt the given buffer of data, writing the result into {@link #destinationBuffer}. If successful, will return <code>true</code>.
|
||||
* If more data is needed in order to perform the decryption, will return <code>false</code>.
|
||||
|
@ -260,15 +319,21 @@ public class PeerChannel implements Closeable {
|
|||
|
||||
while (true) {
|
||||
final SSLEngineResult result = sslEngine.unwrap(encrypted, destinationBuffer);
|
||||
logOperationResult("UNWRAP", result);
|
||||
|
||||
switch (result.getStatus()) {
|
||||
case OK:
|
||||
if (SSLEngineResult.HandshakeStatus.FINISHED == result.getHandshakeStatus()) {
|
||||
// RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
|
||||
// Break out of switch statement to call SSLEngine.unwrap() again
|
||||
break;
|
||||
}
|
||||
destinationBuffer.flip();
|
||||
return true;
|
||||
case CLOSED:
|
||||
throw new IOException("Failed to decrypt data from Peer " + peerDescription + " because Peer unexpectedly closed connection");
|
||||
case BUFFER_OVERFLOW:
|
||||
// ecnryptedBuffer is not large enough. Need to increase the size.
|
||||
// encryptedBuffer is not large enough. Need to increase the size.
|
||||
final ByteBuffer tempBuffer = ByteBuffer.allocate(encrypted.position() + sslEngine.getSession().getApplicationBufferSize());
|
||||
destinationBuffer.flip();
|
||||
tempBuffer.put(destinationBuffer);
|
||||
|
@ -282,7 +347,11 @@ public class PeerChannel implements Closeable {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Perform TLS handshake when SSLEngine configured
|
||||
*
|
||||
* @throws IOException Thrown on failure to handle socket communication or TLS packet processing
|
||||
*/
|
||||
public void performHandshake() throws IOException {
|
||||
if (sslEngine == null) {
|
||||
return;
|
||||
|
@ -295,18 +364,17 @@ public class PeerChannel implements Closeable {
|
|||
|
||||
while (true) {
|
||||
final SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
|
||||
logHandshakeStatus(handshakeStatus);
|
||||
|
||||
switch (handshakeStatus) {
|
||||
case FINISHED:
|
||||
case NOT_HANDSHAKING:
|
||||
streamBuffer.clear();
|
||||
destinationBuffer.clear();
|
||||
logger.debug("Completed SSL Handshake with Peer {}", peerDescription);
|
||||
logHandshakeCompleted();
|
||||
return;
|
||||
|
||||
case NEED_TASK:
|
||||
logger.debug("SSL Handshake with Peer {} Needs Task", peerDescription);
|
||||
|
||||
Runnable runnable;
|
||||
while ((runnable = sslEngine.getDelegatedTask()) != null) {
|
||||
runnable.run();
|
||||
|
@ -314,27 +382,22 @@ public class PeerChannel implements Closeable {
|
|||
break;
|
||||
|
||||
case NEED_WRAP:
|
||||
logger.trace("SSL Handshake with Peer {} Needs Wrap", peerDescription);
|
||||
|
||||
encrypt(emptyMessage);
|
||||
final int bytesWritten = write(destinationBuffer);
|
||||
logger.debug("Wrote {} bytes for NEED_WRAP portion of Handshake", bytesWritten);
|
||||
logHandshakeStatusBytes(handshakeStatus, "Socket write completed", bytesWritten);
|
||||
break;
|
||||
|
||||
case NEED_UNWRAP:
|
||||
logger.trace("SSL Handshake with Peer {} Needs Unwrap", peerDescription);
|
||||
|
||||
while (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
|
||||
final boolean decrypted = decrypt(unwrapBuffer);
|
||||
if (decrypted) {
|
||||
logger.trace("Decryption was successful for NEED_UNWRAP portion of Handshake");
|
||||
final SSLEngineResult.HandshakeStatus unwrapHandshakeStatus = sslEngine.getHandshakeStatus();
|
||||
if (decrypted || unwrapHandshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
|
||||
logHandshakeStatus(unwrapHandshakeStatus, "Decryption completed");
|
||||
break;
|
||||
}
|
||||
|
||||
if (unwrapBuffer.capacity() - unwrapBuffer.position() < 1) {
|
||||
logger.trace("Enlarging size of Buffer for NEED_UNWRAP portion of Handshake");
|
||||
|
||||
// destinationBuffer is not large enough. Need to increase the size.
|
||||
logHandshakeStatus(unwrapHandshakeStatus, "Increasing unwrap buffer for decryption");
|
||||
final ByteBuffer tempBuffer = ByteBuffer.allocate(unwrapBuffer.capacity() + sslEngine.getSession().getApplicationBufferSize());
|
||||
tempBuffer.put(unwrapBuffer);
|
||||
unwrapBuffer = tempBuffer;
|
||||
|
@ -342,17 +405,37 @@ public class PeerChannel implements Closeable {
|
|||
continue;
|
||||
}
|
||||
|
||||
logger.trace("Need to read more bytes for NEED_UNWRAP portion of Handshake");
|
||||
|
||||
// Need to read more data.
|
||||
logHandshakeStatus(unwrapHandshakeStatus, "Socket read started");
|
||||
unwrapBuffer.compact();
|
||||
final int bytesRead = socketChannel.read(unwrapBuffer);
|
||||
unwrapBuffer.flip();
|
||||
logger.debug("Read {} bytes for NEED_UNWRAP portion of Handshake", bytesRead);
|
||||
|
||||
logHandshakeStatusBytes(unwrapHandshakeStatus, "Socket read completed", bytesRead);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void logOperationResult(final String operation, final SSLEngineResult sslEngineResult) {
|
||||
logger.trace("Channel [{}] {} [{}]", peerDescription, operation, sslEngineResult);
|
||||
}
|
||||
|
||||
private void logHandshakeCompleted() {
|
||||
final SSLSession sslSession = sslEngine.getSession();
|
||||
logger.debug("Channel [{}] Handshake Completed Protocol [{}] Cipher Suite [{}]", peerDescription, sslSession.getProtocol(), sslSession.getCipherSuite());
|
||||
}
|
||||
|
||||
private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus) {
|
||||
logger.debug("Channel [{}] Handshake Status [{}]", peerDescription, handshakeStatus);
|
||||
}
|
||||
|
||||
private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus, final String operation) {
|
||||
logger.debug("Channel [{}] Handshake Status [{}] {}", peerDescription, handshakeStatus, operation);
|
||||
}
|
||||
|
||||
private void logHandshakeStatusBytes(final SSLEngineResult.HandshakeStatus handshakeStatus, final String operation, final int bytes) {
|
||||
logger.debug("Channel [{}] Handshake Status [{}] {} Bytes [{}]", peerDescription, handshakeStatus, operation, bytes);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,254 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.controller.queue.clustered.client.async.nio;
|
||||
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.EventLoopGroup;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.handler.ssl.SslHandler;
|
||||
import org.apache.nifi.remote.io.socket.NetworkUtils;
|
||||
import org.apache.nifi.security.util.SslContextFactory;
|
||||
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
|
||||
import org.apache.nifi.security.util.TlsConfiguration;
|
||||
import org.apache.nifi.security.util.TlsPlatform;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.condition.EnabledIf;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.Socket;
|
||||
import java.net.SocketException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestPeerChannel {
|
||||
private static final String LOCALHOST = "localhost";
|
||||
|
||||
private static final int GROUP_THREADS = 1;
|
||||
|
||||
private static final boolean CLIENT_CHANNEL = true;
|
||||
|
||||
private static final boolean SERVER_CHANNEL = false;
|
||||
|
||||
private static final long READ_SLEEP_INTERVAL = 500;
|
||||
|
||||
private static final int CHANNEL_TIMEOUT = 15000;
|
||||
|
||||
private static final int SOCKET_TIMEOUT = 5000;
|
||||
|
||||
private static final long SHUTDOWN_TIMEOUT = 100;
|
||||
|
||||
private static final String TLS_1_3 = "TLSv1.3";
|
||||
|
||||
private static final String TLS_1_2 = "TLSv1.2";
|
||||
|
||||
private static final String TLS_1_3_SUPPORTED = "isTls13Supported";
|
||||
|
||||
private static final int PROTOCOL_VERSION = 1;
|
||||
|
||||
private static final int VERSION_ACCEPTED = 0x10;
|
||||
|
||||
private static SSLContext sslContext;
|
||||
|
||||
public static boolean isTls13Supported() {
|
||||
return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
|
||||
}
|
||||
|
||||
@BeforeAll
|
||||
public static void setConfiguration() throws GeneralSecurityException {
|
||||
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
|
||||
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
|
||||
public void testConnectedClose() throws IOException {
|
||||
final String enabledProtocol = getEnabledProtocol();
|
||||
|
||||
processChannel(enabledProtocol, peerChannel -> {});
|
||||
}
|
||||
|
||||
@Test
|
||||
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
|
||||
public void testConnectedWriteReadCloseTls12() throws IOException {
|
||||
assertWriteReadSuccess(TLS_1_2);
|
||||
}
|
||||
|
||||
@EnabledIf(TLS_1_3_SUPPORTED)
|
||||
@Test
|
||||
@Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
|
||||
public void testConnectedWriteReadCloseTls13() throws IOException {
|
||||
assertWriteReadSuccess(TLS_1_3);
|
||||
}
|
||||
|
||||
private void assertWriteReadSuccess(final String enabledProtocol) throws IOException {
|
||||
processChannel(enabledProtocol, peerChannel -> {
|
||||
try {
|
||||
peerChannel.performHandshake();
|
||||
|
||||
final byte[] version = new byte[]{PROTOCOL_VERSION};
|
||||
final ByteBuffer versionBuffer = ByteBuffer.wrap(version);
|
||||
final ByteBuffer encryptedVersionBuffer = peerChannel.prepareForWrite(versionBuffer);
|
||||
peerChannel.write(encryptedVersionBuffer);
|
||||
|
||||
final int firstByteRead = read(peerChannel);
|
||||
assertEquals(PROTOCOL_VERSION, firstByteRead, "Peer Channel first byte read not matched");
|
||||
|
||||
final byte[] versionAccepted = new byte[]{VERSION_ACCEPTED};
|
||||
final ByteBuffer versionAcceptedBuffer = ByteBuffer.wrap(versionAccepted);
|
||||
final ByteBuffer encryptedVersionAcceptedBuffer = peerChannel.prepareForWrite(versionAcceptedBuffer);
|
||||
peerChannel.write(encryptedVersionAcceptedBuffer);
|
||||
|
||||
final int secondByteRead = read(peerChannel);
|
||||
assertEquals(VERSION_ACCEPTED, secondByteRead, "Peer Channel second byte read not matched");
|
||||
} catch (final IOException e) {
|
||||
throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private int read(final PeerChannel peerChannel) throws IOException {
|
||||
OptionalInt read = peerChannel.read();
|
||||
while (!read.isPresent()) {
|
||||
try {
|
||||
TimeUnit.MILLISECONDS.sleep(READ_SLEEP_INTERVAL);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException("Peer Channel read sleep interrupted", e);
|
||||
}
|
||||
read = peerChannel.read();
|
||||
}
|
||||
return read.getAsInt();
|
||||
}
|
||||
|
||||
private void processChannel(final String enabledProtocol, final Consumer<PeerChannel> channelConsumer) throws IOException {
|
||||
final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
|
||||
|
||||
try (final SocketChannel socketChannel = SocketChannel.open()) {
|
||||
final Socket socket = socketChannel.socket();
|
||||
socket.setSoTimeout(SOCKET_TIMEOUT);
|
||||
|
||||
final InetSocketAddress serverSocketAddress = getServerSocketAddress();
|
||||
startServer(group, serverSocketAddress.getPort(), enabledProtocol);
|
||||
|
||||
socketChannel.connect(serverSocketAddress);
|
||||
final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
|
||||
|
||||
final PeerChannel peerChannel = new PeerChannel(socketChannel, sslEngine, serverSocketAddress.toString());
|
||||
assertConnectedOpen(peerChannel);
|
||||
|
||||
socketChannel.configureBlocking(false);
|
||||
channelConsumer.accept(peerChannel);
|
||||
|
||||
peerChannel.close();
|
||||
assertNotConnectedNotOpen(peerChannel);
|
||||
} finally {
|
||||
shutdownGroup(group);
|
||||
}
|
||||
}
|
||||
|
||||
private void assertConnectedOpen(final PeerChannel peerChannel) {
|
||||
assertTrue(peerChannel.isConnected(), "Channel not connected");
|
||||
assertTrue(peerChannel.isOpen(), "Channel not open");
|
||||
}
|
||||
|
||||
private void assertNotConnectedNotOpen(final PeerChannel peerChannel) {
|
||||
assertFalse(peerChannel.isConnected(), "Channel connected");
|
||||
assertFalse(peerChannel.isOpen(), "Channel open");
|
||||
}
|
||||
|
||||
private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
|
||||
final ServerBootstrap bootstrap = new ServerBootstrap();
|
||||
bootstrap.group(group);
|
||||
bootstrap.channel(NioServerSocketChannel.class);
|
||||
bootstrap.childHandler(new ChannelInitializer<Channel>() {
|
||||
@Override
|
||||
protected void initChannel(final Channel channel) {
|
||||
final ChannelPipeline pipeline = channel.pipeline();
|
||||
final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL);
|
||||
setPipelineHandlers(pipeline, sslEngine);
|
||||
pipeline.addLast(new SimpleChannelInboundHandler<ByteBuf>() {
|
||||
private int protocolVersion;
|
||||
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf) throws Exception {
|
||||
if (byteBuf.readableBytes() == 1) {
|
||||
final int read = byteBuf.readByte();
|
||||
if (PROTOCOL_VERSION == read) {
|
||||
protocolVersion = read;
|
||||
channelHandlerContext.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{PROTOCOL_VERSION}));
|
||||
} else if (protocolVersion == PROTOCOL_VERSION) {
|
||||
channelHandlerContext.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{VERSION_ACCEPTED}));
|
||||
} else {
|
||||
throw new SocketException(String.format("Unexpected Integer [%d] read", read));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, port);
|
||||
bindFuture.syncUninterruptibly();
|
||||
}
|
||||
|
||||
private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) {
|
||||
final SSLEngine sslEngine = sslContext.createSSLEngine();
|
||||
sslEngine.setUseClientMode(useClientMode);
|
||||
sslEngine.setEnabledProtocols(new String[]{enabledProtocol});
|
||||
return sslEngine;
|
||||
}
|
||||
|
||||
|
||||
private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) {
|
||||
pipeline.addLast(new SslHandler(sslEngine));
|
||||
}
|
||||
|
||||
private void shutdownGroup(final EventLoopGroup group) {
|
||||
group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
|
||||
}
|
||||
|
||||
private InetSocketAddress getServerSocketAddress() {
|
||||
final int port = NetworkUtils.getAvailableTcpPort();
|
||||
return new InetSocketAddress(LOCALHOST, port);
|
||||
}
|
||||
|
||||
private String getEnabledProtocol() {
|
||||
return isTls13Supported() ? TLS_1_3 : TLS_1_2;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue