mirror of https://github.com/apache/nifi.git
NIFI-10277 This closes #6245. Refactored ConnectionLoadBalanceServerTest
NIFI-10277 Changed failure test to throw SocketException NIFI-10277 Increased timeout to 30 seconds and moved verify method Signed-off-by: Joe Witt <joewitt@apache.org>
This commit is contained in:
parent
c6999ba9d8
commit
02e37713b3
|
@ -1,141 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.controller.queue.clustered.server
|
||||
|
||||
import org.apache.nifi.events.EventReporter
|
||||
import org.apache.nifi.reporting.Severity
|
||||
import org.apache.nifi.security.util.SslContextFactory
|
||||
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder
|
||||
import org.apache.nifi.security.util.TlsConfiguration
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.BeforeClass
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.JUnit4
|
||||
|
||||
import javax.net.ssl.SSLContext
|
||||
import javax.net.ssl.SSLPeerUnverifiedException
|
||||
import javax.net.ssl.SSLServerSocket
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
class ConnectionLoadBalanceServerTest extends GroovyTestCase {
|
||||
private static final String HOSTNAME = "localhost"
|
||||
private static final int PORT = 54321
|
||||
private static final int NUM_THREADS = 1
|
||||
private static final int TIMEOUT_MS = 1000
|
||||
|
||||
private static TlsConfiguration tlsConfiguration
|
||||
private static SSLContext sslContext
|
||||
|
||||
private ConnectionLoadBalanceServer lbServer
|
||||
|
||||
@BeforeClass
|
||||
static void setUpOnce() throws Exception {
|
||||
tlsConfiguration = new TemporaryKeyStoreBuilder().build()
|
||||
sslContext = SslContextFactory.createSslContext(tlsConfiguration)
|
||||
}
|
||||
|
||||
@Before
|
||||
void setUp() {
|
||||
}
|
||||
|
||||
@After
|
||||
void tearDown() {
|
||||
if (lbServer) {
|
||||
lbServer.stop()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRequestPeerListShouldUseTLS() {
|
||||
// Arrange
|
||||
SSLContext sslContext = SslContextFactory.createSslContext(tlsConfiguration)
|
||||
|
||||
def mockLBP = [
|
||||
receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
|
||||
] as LoadBalanceProtocol
|
||||
def mockER = [:] as EventReporter
|
||||
|
||||
lbServer = new ConnectionLoadBalanceServer(HOSTNAME, PORT, sslContext, NUM_THREADS, mockLBP, mockER, TIMEOUT_MS)
|
||||
|
||||
// Act
|
||||
lbServer.start()
|
||||
|
||||
// Assert
|
||||
|
||||
// Assert that the actual socket is set correctly due to the override in the LB server
|
||||
SSLServerSocket socket = lbServer.serverSocket as SSLServerSocket
|
||||
assert socket.needClientAuth
|
||||
|
||||
// Clean up
|
||||
lbServer.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldHandleSSLPeerUnverifiedException() {
|
||||
// Arrange
|
||||
final long testStartMillis = System.currentTimeMillis()
|
||||
final int CONNECTION_ATTEMPTS = 100
|
||||
// If this test takes longer than 3 seconds, it's likely because of external delays, which would invalidate the assertions
|
||||
final long MAX_TEST_DURATION_MILLIS = 3000
|
||||
final String peerDescription = "Test peer"
|
||||
final SSLPeerUnverifiedException e = new SSLPeerUnverifiedException("Test exception")
|
||||
|
||||
InputStream socketInputStream = new ByteArrayInputStream("This is the socket input stream".bytes)
|
||||
OutputStream socketOutputStream = new ByteArrayOutputStream()
|
||||
|
||||
Socket mockSocket = [
|
||||
getInputStream : { -> socketInputStream },
|
||||
getOutputStream: { -> socketOutputStream },
|
||||
] as Socket
|
||||
LoadBalanceProtocol mockLBProtocol = [
|
||||
receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
|
||||
] as LoadBalanceProtocol
|
||||
EventReporter mockER = [
|
||||
reportEvent: { Severity s, String c, String m -> }
|
||||
] as EventReporter
|
||||
|
||||
def output = [debug: 0, error: 0]
|
||||
|
||||
ConnectionLoadBalanceServer.CommunicateAction communicateAction = new ConnectionLoadBalanceServer.CommunicateAction(mockLBProtocol, mockSocket, mockER)
|
||||
|
||||
// Override the threshold to 100 ms
|
||||
communicateAction.EXCEPTION_THRESHOLD_MILLIS = 100
|
||||
|
||||
// Act
|
||||
CONNECTION_ATTEMPTS.times { int i ->
|
||||
boolean printedError = communicateAction.handleTlsError(peerDescription, e)
|
||||
if (printedError) {
|
||||
output.error++
|
||||
} else {
|
||||
output.debug++
|
||||
}
|
||||
sleep(10)
|
||||
}
|
||||
|
||||
// Only enforce if the test completed in a reasonable amount of time (i.e. external delays did not influence the timing)
|
||||
long testStopMillis = System.currentTimeMillis()
|
||||
long testDurationMillis = testStopMillis - testStartMillis
|
||||
if (testDurationMillis <= MAX_TEST_DURATION_MILLIS) {
|
||||
assert output.debug > output.error
|
||||
}
|
||||
|
||||
// Clean up
|
||||
communicateAction.stop()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.controller.queue.clustered.server;
|
||||
|
||||
import org.apache.nifi.events.EventReporter;
|
||||
import org.apache.nifi.remote.io.socket.NetworkUtils;
|
||||
import org.apache.nifi.security.util.SslContextFactory;
|
||||
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
|
||||
import org.apache.nifi.security.util.TlsConfiguration;
|
||||
import org.apache.nifi.security.util.TlsException;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLException;
|
||||
import javax.net.ssl.SSLSocket;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.Socket;
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class ConnectionLoadBalanceServerTest {
|
||||
|
||||
static final String LOCALHOST = "127.0.0.1";
|
||||
|
||||
static final int SERVER_THREADS = 1;
|
||||
|
||||
static final int SOCKET_TIMEOUT_MILLIS = 30000;
|
||||
|
||||
static final Duration SOCKET_TIMEOUT = Duration.ofMillis(SOCKET_TIMEOUT_MILLIS);
|
||||
|
||||
static SSLContext sslContext;
|
||||
|
||||
@Mock
|
||||
EventReporter eventReporter;
|
||||
|
||||
@BeforeAll
|
||||
static void setSslContext() throws TlsException {
|
||||
final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
|
||||
sslContext = SslContextFactory.createSslContext(tlsConfiguration);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSslContextSocketHandshakeCompleted() throws IOException {
|
||||
final ConnectionLoadBalanceServer server = getServer(new SslHandshakeCompletedLoadBalanceProtocol());
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
final SSLSocketFactory socketFactory = sslContext.getSocketFactory();
|
||||
try (final Socket socket = socketFactory.createSocket()) {
|
||||
final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort());
|
||||
assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed");
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testHandshakeCompletedProtocolException() throws IOException {
|
||||
final ConnectionLoadBalanceServer server = getServer(new ReceiveFlowFilesSslExceptionLoadBalanceProtocol());
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
final SSLSocketFactory socketFactory = sslContext.getSocketFactory();
|
||||
try (final Socket socket = socketFactory.createSocket()) {
|
||||
final SSLSocket sslSocket = assertSocketConnected(socket, server.getPort());
|
||||
assertTimeoutPreemptively(SOCKET_TIMEOUT, sslSocket::startHandshake, "TLS handshake failed");
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
private SSLSocket assertSocketConnected(final Socket socket, final int port) throws IOException {
|
||||
assertInstanceOf(SSLSocket.class, socket);
|
||||
|
||||
final SSLSocket sslSocket = (SSLSocket) socket;
|
||||
final InetSocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port);
|
||||
sslSocket.connect(socketAddress, SOCKET_TIMEOUT_MILLIS);
|
||||
|
||||
assertTrue(sslSocket.isConnected());
|
||||
return sslSocket;
|
||||
}
|
||||
|
||||
private ConnectionLoadBalanceServer getServer(final LoadBalanceProtocol loadBalanceProtocol) {
|
||||
final int port = NetworkUtils.getAvailableTcpPort();
|
||||
return new ConnectionLoadBalanceServer(
|
||||
LOCALHOST,
|
||||
port,
|
||||
sslContext,
|
||||
SERVER_THREADS,
|
||||
loadBalanceProtocol,
|
||||
eventReporter,
|
||||
SOCKET_TIMEOUT_MILLIS
|
||||
);
|
||||
}
|
||||
|
||||
static class SslHandshakeCompletedLoadBalanceProtocol implements LoadBalanceProtocol {
|
||||
|
||||
@Override
|
||||
public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException {
|
||||
final SSLSocket sslSocket = (SSLSocket) socket;
|
||||
sslSocket.startHandshake();
|
||||
}
|
||||
}
|
||||
|
||||
static class ReceiveFlowFilesSslExceptionLoadBalanceProtocol implements LoadBalanceProtocol {
|
||||
final AtomicBoolean receiveFlowFilesInvoked = new AtomicBoolean();
|
||||
|
||||
@Override
|
||||
public void receiveFlowFiles(final Socket socket, final InputStream in, final OutputStream out) throws IOException {
|
||||
if (receiveFlowFilesInvoked.get()) {
|
||||
throw new SSLException(SSLException.class.getSimpleName());
|
||||
} else {
|
||||
final SSLSocket sslSocket = (SSLSocket) socket;
|
||||
sslSocket.startHandshake();
|
||||
}
|
||||
receiveFlowFilesInvoked.getAndSet(true);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue