diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/groovy/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.groovy b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/groovy/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.groovy deleted file mode 100644 index 3a1c3039b2..0000000000 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/groovy/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.groovy +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.controller.queue.clustered.server - -import org.apache.nifi.events.EventReporter -import org.apache.nifi.reporting.Severity -import org.apache.nifi.security.util.SslContextFactory -import org.apache.nifi.security.util.TemporaryKeyStoreBuilder -import org.apache.nifi.security.util.TlsConfiguration -import org.junit.After -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 - -import javax.net.ssl.SSLContext -import javax.net.ssl.SSLPeerUnverifiedException -import javax.net.ssl.SSLServerSocket - -@RunWith(JUnit4.class) -class ConnectionLoadBalanceServerTest extends GroovyTestCase { - private static final String HOSTNAME = "localhost" - private static final int PORT = 54321 - private static final int NUM_THREADS = 1 - private static final int TIMEOUT_MS = 1000 - - private static TlsConfiguration tlsConfiguration - private static SSLContext sslContext - - private ConnectionLoadBalanceServer lbServer - - @BeforeClass - static void setUpOnce() throws Exception { - tlsConfiguration = new TemporaryKeyStoreBuilder().build() - sslContext = SslContextFactory.createSslContext(tlsConfiguration) - } - - @Before - void setUp() { - } - - @After - void tearDown() { - if (lbServer) { - lbServer.stop() - } - } - - @Test - void testRequestPeerListShouldUseTLS() { - // Arrange - SSLContext sslContext = SslContextFactory.createSslContext(tlsConfiguration) - - def mockLBP = [ - receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null } - ] as LoadBalanceProtocol - def mockER = [:] as EventReporter - - lbServer = new ConnectionLoadBalanceServer(HOSTNAME, PORT, sslContext, NUM_THREADS, mockLBP, mockER, TIMEOUT_MS) - - // Act - lbServer.start() - - // Assert - - // Assert that the actual socket is set correctly due to the override in the LB server - SSLServerSocket socket = lbServer.serverSocket as SSLServerSocket - assert socket.needClientAuth - - // Clean up - lbServer.stop() - } - - @Test - void testShouldHandleSSLPeerUnverifiedException() { - // Arrange - final long testStartMillis = System.currentTimeMillis() - final int CONNECTION_ATTEMPTS = 100 - // If this test takes longer than 3 seconds, it's likely because of external delays, which would invalidate the assertions - final long MAX_TEST_DURATION_MILLIS = 3000 - final String peerDescription = "Test peer" - final SSLPeerUnverifiedException e = new SSLPeerUnverifiedException("Test exception") - - InputStream socketInputStream = new ByteArrayInputStream("This is the socket input stream".bytes) - OutputStream socketOutputStream = new ByteArrayOutputStream() - - Socket mockSocket = [ - getInputStream : { -> socketInputStream }, - getOutputStream: { -> socketOutputStream }, - ] as Socket - LoadBalanceProtocol mockLBProtocol = [ - receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null } - ] as LoadBalanceProtocol - EventReporter mockER = [ - reportEvent: { Severity s, String c, String m -> } - ] as EventReporter - - def output = [debug: 0, error: 0] - - ConnectionLoadBalanceServer.CommunicateAction communicateAction = new ConnectionLoadBalanceServer.CommunicateAction(mockLBProtocol, mockSocket, mockER) - - // Override the threshold to 100 ms - communicateAction.EXCEPTION_THRESHOLD_MILLIS = 100 - - // Act - CONNECTION_ATTEMPTS.times { int i -> - boolean printedError = communicateAction.handleTlsError(peerDescription, e) - if (printedError) { - output.error++ - } else { - output.debug++ - } - sleep(10) - } - - // Only enforce if the test completed in a reasonable amount of time (i.e. external delays did not influence the timing) - long testStopMillis = System.currentTimeMillis() - long testDurationMillis = testStopMillis - testStartMillis - if (testDurationMillis <= MAX_TEST_DURATION_MILLIS) { - assert output.debug > output.error - } - - // Clean up - communicateAction.stop() - } -} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.java new file mode 100644 index 0000000000..c16baa3082 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/queue/clustered/server/ConnectionLoadBalanceServerTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.controller.queue.clustered.server; + +import org.apache.nifi.events.EventReporter; +import org.apache.nifi.remote.io.socket.NetworkUtils; +import org.apache.nifi.security.util.SslContextFactory; +import org.apache.nifi.security.util.TemporaryKeyStoreBuilder; +import org.apache.nifi.security.util.TlsConfiguration; +import org.apache.nifi.security.util.TlsException; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(MockitoExtension.class) +class ConnectionLoadBalanceServerTest { + + static final String LOCALHOST = "127.0.0.1"; + + static final int SERVER_THREADS = 1; + + static final int SOCKET_TIMEOUT_MILLIS = 30000; + + static final Duration SOCKET_TIMEOUT = Duration.ofMillis(SOCKET_TIMEOUT_MILLIS); + + static SSLContext sslContext; + + @Mock + EventReporter eventReporter; + + @BeforeAll + static void setSslContext() throws TlsException { + final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build(); + sslContext = SslContextFactory.createSslContext(tlsConfiguration); + } + + @Test + void testSslContextSocketHandshakeCompleted() throws IOException { + final ConnectionLoadBalanceServer server = getServer(new SslHandshakeCompletedLoadBalanceProtocol()); + + try { + server.start(); + + final SSLSocketFactory socketFactory = sslContext.getSocketFactory(); + try (final Socket socket = socketFactory.createSocket()) { + final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort()); + assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed"); + } + } finally { + server.stop(); + } + } + + @Test + void testHandshakeCompletedProtocolException() throws IOException { + final ConnectionLoadBalanceServer server = getServer(new ReceiveFlowFilesSslExceptionLoadBalanceProtocol()); + + try { + server.start(); + + final SSLSocketFactory socketFactory = sslContext.getSocketFactory(); + try (final Socket socket = socketFactory.createSocket()) { + final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort()); + assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed"); + } + } finally { + server.stop(); + } + } + + private SSLSocket assertSocketConnected(final Socket socket, final int port) throws IOException { + assertInstanceOf(SSLSocket.class, socket); + + final SSLSocket sslSocket = (SSLSocket) socket; + final InetSocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port); + sslSocket.connect(socketAddress, SOCKET_TIMEOUT_MILLIS); + + assertTrue(sslSocket.isConnected()); + return sslSocket; + } + + private ConnectionLoadBalanceServer getServer(final LoadBalanceProtocol loadBalanceProtocol) { + final int port = NetworkUtils.getAvailableTcpPort(); + return new ConnectionLoadBalanceServer( + LOCALHOST, + port, + sslContext, + SERVER_THREADS, + loadBalanceProtocol, + eventReporter, + SOCKET_TIMEOUT_MILLIS + ); + } + + static class SslHandshakeCompletedLoadBalanceProtocol implements LoadBalanceProtocol { + + @Override + public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException { + final SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.startHandshake(); + } + } + + static class ReceiveFlowFilesSslExceptionLoadBalanceProtocol implements LoadBalanceProtocol { + final AtomicBoolean receiveFlowFilesInvoked = new AtomicBoolean(); + + @Override + public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException { + if (receiveFlowFilesInvoked.get()) { + throw new SSLException(SSLException.class.getSimpleName()); + } else { + final SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.startHandshake(); + } + receiveFlowFilesInvoked.getAndSet(true); + } + } +}