NIFI-10277 This closes #6245. Refactored ConnectionLoadBalanceServerTest

NIFI-10277 Changed failure test to throw SocketException
NIFI-10277 Increased timeout to 30 seconds and moved verify method

Signed-off-by: Joe Witt <joewitt@apache.org>
This commit is contained in:
exceptionfactory 2022-07-25 12:04:21 -05:00 committed by Joe Witt
parent c6999ba9d8
commit 02e37713b3
No known key found for this signature in database
GPG Key ID: 9093BF854F811A1A
2 changed files with 150 additions and 141 deletions

View File

@ -1,141 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.controller.queue.clustered.server
import org.apache.nifi.events.EventReporter
import org.apache.nifi.reporting.Severity
import org.apache.nifi.security.util.SslContextFactory
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder
import org.apache.nifi.security.util.TlsConfiguration
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLServerSocket
@RunWith(JUnit4.class)
class ConnectionLoadBalanceServerTest extends GroovyTestCase {
private static final String HOSTNAME = "localhost"
private static final int PORT = 54321
private static final int NUM_THREADS = 1
private static final int TIMEOUT_MS = 1000
private static TlsConfiguration tlsConfiguration
private static SSLContext sslContext
private ConnectionLoadBalanceServer lbServer
@BeforeClass
static void setUpOnce() throws Exception {
tlsConfiguration = new TemporaryKeyStoreBuilder().build()
sslContext = SslContextFactory.createSslContext(tlsConfiguration)
}
@Before
void setUp() {
}
@After
void tearDown() {
if (lbServer) {
lbServer.stop()
}
}
@Test
void testRequestPeerListShouldUseTLS() {
// Arrange
SSLContext sslContext = SslContextFactory.createSslContext(tlsConfiguration)
def mockLBP = [
receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
] as LoadBalanceProtocol
def mockER = [:] as EventReporter
lbServer = new ConnectionLoadBalanceServer(HOSTNAME, PORT, sslContext, NUM_THREADS, mockLBP, mockER, TIMEOUT_MS)
// Act
lbServer.start()
// Assert
// Assert that the actual socket is set correctly due to the override in the LB server
SSLServerSocket socket = lbServer.serverSocket as SSLServerSocket
assert socket.needClientAuth
// Clean up
lbServer.stop()
}
@Test
void testShouldHandleSSLPeerUnverifiedException() {
// Arrange
final long testStartMillis = System.currentTimeMillis()
final int CONNECTION_ATTEMPTS = 100
// If this test takes longer than 3 seconds, it's likely because of external delays, which would invalidate the assertions
final long MAX_TEST_DURATION_MILLIS = 3000
final String peerDescription = "Test peer"
final SSLPeerUnverifiedException e = new SSLPeerUnverifiedException("Test exception")
InputStream socketInputStream = new ByteArrayInputStream("This is the socket input stream".bytes)
OutputStream socketOutputStream = new ByteArrayOutputStream()
Socket mockSocket = [
getInputStream : { -> socketInputStream },
getOutputStream: { -> socketOutputStream },
] as Socket
LoadBalanceProtocol mockLBProtocol = [
receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
] as LoadBalanceProtocol
EventReporter mockER = [
reportEvent: { Severity s, String c, String m -> }
] as EventReporter
def output = [debug: 0, error: 0]
ConnectionLoadBalanceServer.CommunicateAction communicateAction = new ConnectionLoadBalanceServer.CommunicateAction(mockLBProtocol, mockSocket, mockER)
// Override the threshold to 100 ms
communicateAction.EXCEPTION_THRESHOLD_MILLIS = 100
// Act
CONNECTION_ATTEMPTS.times { int i ->
boolean printedError = communicateAction.handleTlsError(peerDescription, e)
if (printedError) {
output.error++
} else {
output.debug++
}
sleep(10)
}
// Only enforce if the test completed in a reasonable amount of time (i.e. external delays did not influence the timing)
long testStopMillis = System.currentTimeMillis()
long testDurationMillis = testStopMillis - testStartMillis
if (testDurationMillis <= MAX_TEST_DURATION_MILLIS) {
assert output.debug > output.error
}
// Clean up
communicateAction.stop()
}
}

View File

@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.controller.queue.clustered.server;
import org.apache.nifi.events.EventReporter;
import org.apache.nifi.remote.io.socket.NetworkUtils;
import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ExtendWith(MockitoExtension.class)
class ConnectionLoadBalanceServerTest {
static final String LOCALHOST = "127.0.0.1";
static final int SERVER_THREADS = 1;
static final int SOCKET_TIMEOUT_MILLIS = 30000;
static final Duration SOCKET_TIMEOUT = Duration.ofMillis(SOCKET_TIMEOUT_MILLIS);
static SSLContext sslContext;
@Mock
EventReporter eventReporter;
@BeforeAll
static void setSslContext() throws TlsException {
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
}
@Test
void testSslContextSocketHandshakeCompleted() throws IOException {
final ConnectionLoadBalanceServer server = getServer(new SslHandshakeCompletedLoadBalanceProtocol());
try {
server.start();
final SSLSocketFactory socketFactory = sslContext.getSocketFactory();
try (final Socket socket = socketFactory.createSocket()) {
final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort());
assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed");
}
} finally {
server.stop();
}
}
@Test
void testHandshakeCompletedProtocolException() throws IOException {
final ConnectionLoadBalanceServer server = getServer(new ReceiveFlowFilesSslExceptionLoadBalanceProtocol());
try {
server.start();
final SSLSocketFactory socketFactory = sslContext.getSocketFactory();
try (final Socket socket = socketFactory.createSocket()) {
final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort());
assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed");
}
} finally {
server.stop();
}
}
private SSLSocket assertSocketConnected(final Socket socket, final int port) throws IOException {
assertInstanceOf(SSLSocket.class, socket);
final SSLSocket sslSocket = (SSLSocket) socket;
final InetSocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port);
sslSocket.connect(socketAddress, SOCKET_TIMEOUT_MILLIS);
assertTrue(sslSocket.isConnected());
return sslSocket;
}
private ConnectionLoadBalanceServer getServer(final LoadBalanceProtocol loadBalanceProtocol) {
final int port = NetworkUtils.getAvailableTcpPort();
return new ConnectionLoadBalanceServer(
LOCALHOST,
port,
sslContext,
SERVER_THREADS,
loadBalanceProtocol,
eventReporter,
SOCKET_TIMEOUT_MILLIS
);
}
static class SslHandshakeCompletedLoadBalanceProtocol implements LoadBalanceProtocol {
@Override
public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException {
final SSLSocket sslSocket = (SSLSocket) socket;
sslSocket.startHandshake();
}
}
static class ReceiveFlowFilesSslExceptionLoadBalanceProtocol implements LoadBalanceProtocol {
final AtomicBoolean receiveFlowFilesInvoked = new AtomicBoolean();
@Override
public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException {
if (receiveFlowFilesInvoked.get()) {
throw new SSLException(SSLException.class.getSimpleName());
} else {
final SSLSocket sslSocket = (SSLSocket) socket;
sslSocket.startHandshake();
}
receiveFlowFilesInvoked.getAndSet(true);
}
}
}