NIFI-11195 Refactored Identity Mapping to nifi-security-identity

- Moved StringUtils from nifi-properties to nifi-property-utils
- Moved Peer Identity methods from CertificateUtils to specific Site-to-Site classes

Signed-off-by: Joe Gresock <jgresock@gmail.com>
This closes #6977.
This commit is contained in:
exceptionfactory 2023-02-16 20:51:25 -06:00 committed by Joe Gresock
parent 87e61c50ee
commit 48689a2567
No known key found for this signature in database
GPG Key ID: 37F5B9B6E258C8B7
36 changed files with 380 additions and 1147 deletions

View File

@ -15,6 +15,12 @@
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-commons</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>nifi-properties</artifactId>
<dependencies>
<dependency>
<groupId>org.apache.nifi</groupId>
@ -23,10 +29,4 @@
<scope>compile</scope>
</dependency>
</dependencies>
<parent>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-commons</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>nifi-properties</artifactId>
</project>

View File

@ -0,0 +1,35 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-commons</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>nifi-security-identity</artifactId>
<dependencies>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-properties</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -18,12 +18,10 @@ package org.apache.nifi.security.util;
import java.io.ByteArrayInputStream;
import java.math.BigInteger;
import java.net.Socket;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.Security;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.CertificateParsingException;
@ -42,8 +40,6 @@ import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
@ -81,14 +77,8 @@ import org.slf4j.LoggerFactory;
public final class CertificateUtils {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtils.class);
private static final String PEER_NOT_AUTHENTICATED_MSG = "peer not authenticated";
private static final Map<ASN1ObjectIdentifier, Integer> dnOrderMap = createDnOrderMap();
public static final String JAVA_8_MAX_SUPPORTED_TLS_PROTOCOL_VERSION = "TLSv1.2";
public static final String JAVA_11_MAX_SUPPORTED_TLS_PROTOCOL_VERSION = "TLSv1.3";
public static final String[] JAVA_8_SUPPORTED_TLS_PROTOCOL_VERSIONS = new String[]{JAVA_8_MAX_SUPPORTED_TLS_PROTOCOL_VERSION};
public static final String[] JAVA_11_SUPPORTED_TLS_PROTOCOL_VERSIONS = new String[]{JAVA_11_MAX_SUPPORTED_TLS_PROTOCOL_VERSION, JAVA_8_MAX_SUPPORTED_TLS_PROTOCOL_VERSION};
static {
Security.addProvider(new BouncyCastleProvider());
}
@ -206,136 +196,6 @@ public final class CertificateUtils {
return result;
}
/**
* Returns the DN extracted from the peer certificate (the server DN if run on the client; the client DN (if available) if run on the server).
* <p>
* If the client auth setting is WANT or NONE and a client certificate is not present, this method will return {@code null}.
* If the client auth is NEED, it will throw a {@link CertificateException}.
*
* @param socket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
public static String extractPeerDNFromSSLSocket(Socket socket) throws CertificateException {
String dn = null;
if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket;
boolean clientMode = sslSocket.getUseClientMode();
logger.debug("SSL Socket in {} mode", clientMode ? "client" : "server");
ClientAuth clientAuth = getClientAuthStatus(sslSocket);
logger.debug("SSL Socket client auth status: {}", clientAuth);
if (clientMode) {
logger.debug("This socket is in client mode, so attempting to extract certificate from remote 'server' socket");
dn = extractPeerDNFromServerSSLSocket(sslSocket);
} else {
logger.debug("This socket is in server mode, so attempting to extract certificate from remote 'client' socket");
dn = extractPeerDNFromClientSSLSocket(sslSocket);
}
}
return dn;
}
/**
* Returns the DN extracted from the client certificate.
* <p>
* If the client auth setting is WANT or NONE and a certificate is not present (and {@code respectClientAuth} is {@code true}), this method will return {@code null}.
* If the client auth is NEED, it will throw a {@link CertificateException}.
*
* @param sslSocket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
private static String extractPeerDNFromClientSSLSocket(SSLSocket sslSocket) throws CertificateException {
String dn = null;
/** The clientAuth value can be "need", "want", or "none"
* A client must send client certificates for need, should for want, and will not for none.
* This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none.
*/
ClientAuth clientAuth = getClientAuthStatus(sslSocket);
logger.debug("SSL Socket client auth status: {}", clientAuth);
if (clientAuth != ClientAuth.NONE) {
try {
final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
if (certChains != null && certChains.length > 0) {
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
dn = x509Certificate.getSubjectDN().getName().trim();
logger.debug("Extracted DN={} from client certificate", dn);
}
} catch (SSLPeerUnverifiedException e) {
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
logger.error("The incoming request did not contain client certificates and thus the DN cannot" +
" be extracted. Check that the other endpoint is providing a complete client certificate chain");
}
if (clientAuth == ClientAuth.WANT) {
logger.warn("Suppressing missing client certificate exception because client auth is set to 'want'");
return null;
}
throw new CertificateException(e);
}
}
return dn;
}
/**
* Returns the DN extracted from the server certificate.
*
* @param socket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
private static String extractPeerDNFromServerSSLSocket(Socket socket) throws CertificateException {
String dn = null;
if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket;
try {
final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
if (certChains != null && certChains.length > 0) {
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
dn = x509Certificate.getSubjectDN().getName().trim();
logger.debug("Extracted DN={} from server certificate", dn);
}
} catch (SSLPeerUnverifiedException e) {
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
logger.error("The server did not present a certificate and thus the DN cannot" +
" be extracted. Check that the other endpoint is providing a complete certificate chain");
}
throw new CertificateException(e);
}
}
return dn;
}
private static ClientAuth getClientAuthStatus(SSLSocket sslSocket) {
return sslSocket.getNeedClientAuth() ? ClientAuth.REQUIRED : sslSocket.getWantClientAuth() ? ClientAuth.WANT : ClientAuth.NONE;
}
/**
* Accepts a legacy {@link javax.security.cert.X509Certificate} and returns an {@link X509Certificate}. The {@code javax.*} package certificate classes are for legacy compatibility and should
* not be used for new development.
*
* @param legacyCertificate the {@code javax.security.cert.X509Certificate}
* @return a new {@code java.security.cert.X509Certificate}
* @throws CertificateException if there is an error generating the new certificate
*/
@SuppressWarnings("deprecation")
public static X509Certificate convertLegacyX509Certificate(javax.security.cert.X509Certificate legacyCertificate) throws CertificateException {
if (legacyCertificate == null) {
throw new IllegalArgumentException("The X.509 certificate cannot be null");
}
try {
return formX509Certificate(legacyCertificate.getEncoded());
} catch (javax.security.cert.CertificateEncodingException e) {
throw new CertificateException(e);
}
}
/**
* Accepts an abstract {@link java.security.cert.Certificate} and returns an {@link X509Certificate}. Because {@code sslSocket.getSession().getPeerCertificates()} returns an array of the
* abstract certificates, they must be translated to X.509 to replace the functionality of {@code sslSocket.getSession().getPeerCertificateChain()}.

View File

@ -1,654 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x500.style.IETFUtils
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.Extensions
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.operator.OperatorCreationException
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder
import org.bouncycastle.util.IPAddress
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.net.ssl.SSLException
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket
import java.security.InvalidKeyException
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.NoSuchAlgorithmException
import java.security.NoSuchProviderException
import java.security.SignatureException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executors
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class CertificateUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtilsTest.class)
private static final int KEY_SIZE = 2048
private static final int DAYS_IN_YEAR = 365
private static final long YESTERDAY = System.currentTimeMillis() - 24 * 60 * 60 * 1000
private static final long ONE_YEAR_FROM_NOW = System.currentTimeMillis() + 365 * 24 * 60 * 60 * 1000
private static final String SIGNATURE_ALGORITHM = "SHA256withRSA"
private static final String PROVIDER = "BC"
private static final String SUBJECT_DN = "CN=NiFi Test Server,OU=Security,O=Apache,ST=CA,C=US"
private static final String SUBJECT_DN_LEGACY_EMAIL_ATTR_RFC2985 = "CN=NiFi Test Server/emailAddress=test@apache.org,OU=Security,O=Apache,ST=CA,C=US"
private static final String ISSUER_DN = "CN=NiFi Test CA,OU=Security,O=Apache,ST=CA,C=US"
private static final List<String> SUBJECT_ALT_NAMES = ["127.0.0.1", "nifi.nifi.apache.org"]
@BeforeAll
static void setUpOnce() {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
/**
* Generates a public/private RSA keypair using the default key size.
*
* @return the keypair
* @throws java.security.NoSuchAlgorithmException if the RSA algorithm is not available
*/
private static KeyPair generateKeyPair() throws NoSuchAlgorithmException {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA")
keyPairGenerator.initialize(KEY_SIZE)
return keyPairGenerator.generateKeyPair()
}
/**
* Generates a signed certificate using an on-demand keypair.
*
* @param dn the DN
* @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException
* @throws java.security.cert.CertificateException*
* @throws java.security.NoSuchProviderException
* @throws java.security.SignatureException
* @throws OperatorCreationException
*/
private
static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException,
NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateSelfSignedX509Certificate(keyPair, dn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
/**
* Generates a certificate signed by the issuer key.
*
* @param dn the subject DN
* @param issuerDn the issuer DN
* @param issuerKey the issuer private key
* @return the certificate
* @throws IOException
* @throws NoSuchAlgorithmException
* @throws CertificateException
* @throws NoSuchProviderException*
* @throws SignatureException* @throws InvalidKeyException
* @throws OperatorCreationException
*/
private
static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException,
NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKey, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
private static X509Certificate[] generateCertificateChain(String dn = SUBJECT_DN, String issuerDn = ISSUER_DN) {
final KeyPair issuerKeyPair = generateKeyPair()
final X509Certificate issuerCertificate = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, issuerDn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
final X509Certificate certificate = generateIssuedCertificate(dn, issuerCertificate, issuerKeyPair)
[certificate, issuerCertificate] as X509Certificate[]
}
@SuppressWarnings("deprecation")
private static javax.security.cert.X509Certificate generateLegacyCertificate(X509Certificate x509Certificate) {
return javax.security.cert.X509Certificate.getInstance(x509Certificate.getEncoded())
}
private static Certificate generateAbstractCertificate(X509Certificate x509Certificate) {
return x509Certificate as Certificate
}
private static Date inFuture(int days) {
return new Date(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(days))
}
@Test
void testShouldConvertAbstractX509Certificate() {
// Arrange
final X509Certificate EXPECTED_NEW_CERTIFICATE = generateCertificate(SUBJECT_DN)
logger.info("Expected certificate: ${EXPECTED_NEW_CERTIFICATE.class.canonicalName} ${EXPECTED_NEW_CERTIFICATE.subjectDN.toString()} (${EXPECTED_NEW_CERTIFICATE.getSerialNumber()})")
// Form the abstract certificate
final Certificate ABSTRACT_CERTIFICATE = generateAbstractCertificate(EXPECTED_NEW_CERTIFICATE)
logger.info("Abstract certificate: ${ABSTRACT_CERTIFICATE.class.canonicalName} (?)")
// Act
X509Certificate convertedCertificate = CertificateUtils.convertAbstractX509Certificate(ABSTRACT_CERTIFICATE)
logger.info("Converted certificate: ${convertedCertificate.class.canonicalName} ${convertedCertificate.subjectDN.toString()} (${convertedCertificate.getSerialNumber()})")
// Assert
assertEquals(EXPECTED_NEW_CERTIFICATE, convertedCertificate)
}
@Test
void testShouldDetermineClientAuthStatusFromSocket() {
// Arrange
SSLSocket needSocket = [getNeedClientAuth: { -> true }] as SSLSocket
SSLSocket wantSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> true }] as SSLSocket
SSLSocket noneSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> false }] as SSLSocket
// Act
ClientAuth needClientAuthStatus = CertificateUtils.getClientAuthStatus(needSocket)
logger.info("Client auth (needSocket): ${needClientAuthStatus}")
ClientAuth wantClientAuthStatus = CertificateUtils.getClientAuthStatus(wantSocket)
logger.info("Client auth (wantSocket): ${wantClientAuthStatus}")
ClientAuth noneClientAuthStatus = CertificateUtils.getClientAuthStatus(noneSocket)
logger.info("Client auth (noneSocket): ${noneClientAuthStatus}")
// Assert
assertEquals(ClientAuth.REQUIRED, needClientAuthStatus)
assertEquals(ClientAuth.WANT, wantClientAuthStatus)
assertEquals(ClientAuth.NONE, noneClientAuthStatus)
}
@Test
void testShouldExtractClientCertificatesFromSSLServerSocketWithAnyClientAuth() {
final String EXPECTED_DN = "CN=ncm.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in client mode, so the peer ("target") is a server
// Create mock sockets for each possible value of ClientAuth
SSLSocket mockNoneSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockNeedSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockWantSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
def resolvedServerDNs = [mockNeedSocket, mockWantSocket, mockNoneSocket].collect { SSLSocket mockSocket ->
logger.info("Running test with socket ClientAuth setting: ${CertificateUtils.getClientAuthStatus(mockSocket)}")
String serverDN = CertificateUtils.extractPeerDNFromSSLSocket(mockNoneSocket)
logger.info("Extracted server DN: ${serverDN}")
serverDN
}
// Assert
resolvedServerDNs.stream().forEach(serverDN -> assertTrue(CertificateUtils.compareDNs(serverDN, EXPECTED_DN)))
}
@Test
void testShouldNotExtractClientCertificatesFromSSLClientSocketWithClientAuthNone() {
// Arrange
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> false }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertNull(clientDN)
}
@Test
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange
SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, null))
}
@Test
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange
SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
// Act
CertificateException ce = assertThrows(CertificateException.class,
() -> CertificateUtils.extractPeerDNFromSSLSocket(mockSocket))
// Assert
assertTrue(ce.getMessage().contains("peer not authenticated"))
}
@Test
void testShouldCompareDNs() {
// Arrange
final String DN_1_ORDERED = "CN=test1.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
logger.info("DN 1 Ordered : ${DN_1_ORDERED}")
final String DN_1_REVERSED = DN_1_ORDERED.split(", ").reverse().join(", ")
logger.info("DN 1 Reversed: ${DN_1_REVERSED}")
final String DN_2_ORDERED = "CN=test2.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
logger.info("DN 2 Ordered : ${DN_2_ORDERED}")
final String DN_2_REVERSED = DN_2_ORDERED.split(", ").reverse().join(", ")
logger.info("DN 2 Reversed: ${DN_2_REVERSED}")
// Act
// True
boolean dn1MatchesSelf = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_ORDERED)
logger.matches("DN 1, DN 1: ${dn1MatchesSelf}")
boolean dn1MatchesReversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_REVERSED)
logger.matches("DN 1, DN 1 (R): ${dn1MatchesReversed}")
boolean emptyMatchesEmpty = CertificateUtils.compareDNs("", "")
logger.matches("empty, empty: ${emptyMatchesEmpty}")
boolean nullMatchesNull = CertificateUtils.compareDNs(null, null)
logger.matches("null, null: ${nullMatchesNull}")
// False
boolean dn1MatchesDn2 = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_ORDERED)
logger.matches("DN 1, DN 2: ${dn1MatchesDn2}")
boolean dn1MatchesDn2Reversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_REVERSED)
logger.matches("DN 1, DN 2 (R): ${dn1MatchesDn2Reversed}")
boolean dn1MatchesEmpty = CertificateUtils.compareDNs(DN_1_ORDERED, "")
logger.matches("DN 1, empty: ${dn1MatchesEmpty}")
// Assert
assertTrue(dn1MatchesReversed)
assertTrue(emptyMatchesEmpty)
assertTrue(nullMatchesNull)
assertFalse(dn1MatchesDn2)
assertFalse(dn1MatchesDn2Reversed)
assertFalse(dn1MatchesEmpty)
}
@Test
void testGetCommonName(){
String dn1 = "CN=testDN,O=testOrg"
String dn2 = "O=testDN,O=testOrg"
assertEquals("testDN", CertificateUtils.getCommonName(dn1))
assertNull(CertificateUtils.getCommonName(dn2))
}
@Test
void testShouldGenerateSelfSignedCert() throws Exception {
String dn = "CN=testDN,O=testOrg"
int days = 365
X509Certificate x509Certificate = CertificateUtils.generateSelfSignedX509Certificate(generateKeyPair(), dn, SIGNATURE_ALGORITHM, days)
Date notAfter = x509Certificate.getNotAfter()
assertTrue(notAfter.after(inFuture(days - 1)))
assertTrue(notAfter.before(inFuture(days + 1)))
Date notBefore = x509Certificate.getNotBefore()
assertTrue(notBefore.after(inFuture(-1)))
assertTrue(notBefore.before(inFuture(1)))
assertEquals(dn, x509Certificate.getIssuerX500Principal().getName())
assertEquals(SIGNATURE_ALGORITHM.toUpperCase(), x509Certificate.getSigAlgName().toUpperCase())
assertEquals("RSA", x509Certificate.getPublicKey().getAlgorithm())
assertEquals(1, x509Certificate.getSubjectAlternativeNames().size())
GeneralName gn = x509Certificate.getSubjectAlternativeNames().iterator().next()
assertEquals(GeneralName.dNSName, gn.getTagNo())
assertEquals("testDN", gn.getName().toString())
x509Certificate.checkValidity()
}
@Test
void testIssueCert() throws Exception {
int days = 365
KeyPair issuerKeyPair = generateKeyPair()
X509Certificate issuer = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, "CN=testCa,O=testOrg", SIGNATURE_ALGORITHM, days)
String dn = "CN=testIssued, O=testOrg"
KeyPair keyPair = generateKeyPair()
X509Certificate x509Certificate = CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKeyPair, SIGNATURE_ALGORITHM, days)
assertEquals(dn, x509Certificate.getSubjectX500Principal().toString())
assertEquals(issuer.getSubjectX500Principal().toString(), x509Certificate.getIssuerX500Principal().toString())
assertEquals(keyPair.getPublic(), x509Certificate.getPublicKey())
Date notAfter = x509Certificate.getNotAfter()
assertTrue(notAfter.after(inFuture(days - 1)))
assertTrue(notAfter.before(inFuture(days + 1)))
Date notBefore = x509Certificate.getNotBefore()
assertTrue(notBefore.after(inFuture(-1)))
assertTrue(notBefore.before(inFuture(1)))
assertEquals(SIGNATURE_ALGORITHM.toUpperCase(), x509Certificate.getSigAlgName().toUpperCase())
assertEquals("RSA", x509Certificate.getPublicKey().getAlgorithm())
x509Certificate.verify(issuerKeyPair.getPublic())
}
@Test
void reorderShouldPutElementsInCorrectOrder() {
String cn = "CN=testcn"
String l = "L=testl"
String st = "ST=testst"
String o = "O=testo"
String ou = "OU=testou"
String c = "C=testc"
String street = "STREET=teststreet"
String dc = "DC=testdc"
String uid = "UID=testuid"
String surname = "SURNAME=testsurname"
String initials = "INITIALS=testinitials"
String givenName = "GIVENNAME=testgivenname"
assertEquals("$cn,$l,$st,$o,$ou,$c,$street,$dc,$uid,$surname,$givenName,$initials".toString(),
CertificateUtils.reorderDn("$surname,$st,$o,$initials,$givenName,$uid,$street,$c,$cn,$ou,$l,$dc"))
}
@Test
void testUniqueSerialNumbers() {
def running = new AtomicBoolean(true)
def executorService = Executors.newCachedThreadPool()
def serialNumbers = Collections.newSetFromMap(new ConcurrentHashMap())
try {
def futures = new ArrayList<Future>()
for (int i = 0; i < 8; i++) {
futures.add(executorService.submit(new Callable<Integer>() {
@Override
Integer call() throws Exception {
int count = 0
while (running.get()) {
def before = System.currentTimeMillis()
def serialNumber = CertificateUtils.getUniqueSerialNumber()
def after = System.currentTimeMillis()
def serialNumberMillis = serialNumber.shiftRight(32)
assertTrue(serialNumberMillis >= before)
assertTrue(serialNumberMillis <= after)
assertTrue(serialNumbers.add(serialNumber))
count++
}
return count
}
}))
}
Thread.sleep(1000)
running.set(false)
def totalRuns = 0
for (int i = 0; i < futures.size(); i++) {
try {
def numTimes = futures.get(i).get()
logger.info("future $i executed $numTimes times")
totalRuns += numTimes
} catch (ExecutionException e) {
throw e.getCause()
}
}
logger.info("Generated ${serialNumbers.size()} unique serial numbers")
assertEquals(totalRuns, serialNumbers.size())
} finally {
executorService.shutdown()
}
}
@Test
void testShouldGenerateIssuedCertificateWithSans() {
// Arrange
final String SUBJECT_DN = "CN=localhost"
final List<String> SANS = ["127.0.0.1", "nifi.nifi.apache.org"]
logger.info("Creating a certificate with subject: ${SUBJECT_DN} and SAN: ${SANS}")
final KeyPair subjectKeyPair = generateKeyPair()
final KeyPair issuerKeyPair = generateKeyPair()
final X509Certificate issuerCertificate = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, ISSUER_DN, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
// Form the SANS into GeneralName instances and populate the container with the array
def gns = SANS.collect { String san ->
new GeneralName(GeneralName.dNSName, san)
}
def generalNames = new GeneralNames(gns as GeneralName[])
logger.info("Created GeneralNames object: ${generalNames.names*.toString()}")
// Form the Extensions object
ExtensionsGenerator extensionsGenerator = new ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.subjectAlternativeName, false, generalNames)
Extensions extensions = extensionsGenerator.generate()
logger.info("Generated extensions object: ${extensions.oids()*.toString()}")
// Act
X509Certificate certificate = CertificateUtils.generateIssuedCertificate(SUBJECT_DN, subjectKeyPair.public, extensions, issuerCertificate, issuerKeyPair, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
logger.info("Issued certificate with subject: ${certificate.getSubjectDN().name} and SAN: ${certificate.getSubjectAlternativeNames().join(",")}")
// Assert
assertEquals(SUBJECT_DN, certificate.getSubjectDN().name)
assertEquals(SANS.size(), certificate.getSubjectAlternativeNames().size())
assertTrue(certificate.getSubjectAlternativeNames()*.last().containsAll(SANS))
}
@Test
void testShouldDetectTlsErrors() {
// Arrange
final String msg = "Test exception"
// SSLPeerUnverifiedException isn't specifically defined in the method, but is a subclass of SSLException so it should be caught
List<Throwable> directErrors = [new TlsException(msg), new SSLPeerUnverifiedException(msg), new CertificateException(msg), new SSLException(msg)]
List<Throwable> causedErrors = directErrors.collect { Throwable cause -> new Exception(msg, cause) } + [
new Exception(msg,
new Exception("Nested $msg",
new Exception("Double nested $msg",
new TlsException("Triple nested $msg"))))]
List<Throwable> unrelatedErrors = [new Exception(msg), new IllegalArgumentException(msg), new NullPointerException(msg)]
// Act
def directResults = directErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
def causedResults = causedErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
def unrelatedResults = unrelatedErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
logger.info("Direct results: ${directResults}")
logger.info("Caused results: ${causedResults}")
logger.info("Unrelated results: ${unrelatedResults}")
// Assert
assertTrue(directResults.every())
assertTrue(causedResults.every())
assertFalse(unrelatedResults.any())
}
@Test
void testGetExtensionsFromCSR() {
// Arrange
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA")
KeyPair keyPair = generator.generateKeyPair()
Extensions sanExtensions = createDomainAlternativeNamesExtensions(SUBJECT_ALT_NAMES, SUBJECT_DN)
JcaPKCS10CertificationRequestBuilder jcaPKCS10CertificationRequestBuilder = new JcaPKCS10CertificationRequestBuilder(new X500Name(SUBJECT_DN), keyPair.getPublic())
jcaPKCS10CertificationRequestBuilder.addAttribute(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest, sanExtensions)
JcaContentSignerBuilder jcaContentSignerBuilder = new JcaContentSignerBuilder("SHA256WITHRSA")
JcaPKCS10CertificationRequest jcaPKCS10CertificationRequest = new JcaPKCS10CertificationRequest(jcaPKCS10CertificationRequestBuilder.build(jcaContentSignerBuilder.build(keyPair.getPrivate())))
// Act
Extensions extensions = CertificateUtils.getExtensionsFromCSR(jcaPKCS10CertificationRequest)
// Assert
assert(extensions.equivalent(sanExtensions))
}
@Test
void testExtractUserNameFromDN() {
String expected = "NiFi Test Server"
assertEquals(CertificateUtils.extractUsername(SUBJECT_DN), expected)
assertEquals(CertificateUtils.extractUsername(SUBJECT_DN_LEGACY_EMAIL_ATTR_RFC2985), expected)
}
// Using this directly from tls-toolkit results in a dependency loop, so it's added here for testing purposes.
private static Extensions createDomainAlternativeNamesExtensions(List<String> domainAlternativeNames, String requestedDn) throws IOException {
List<GeneralName> namesList = new ArrayList<>()
try {
final String cn = IETFUtils.valueToString(new X500Name(requestedDn).getRDNs(BCStyle.CN)[0].getFirst().getValue())
namesList.add(new GeneralName(GeneralName.dNSName, cn))
} catch (Exception e) {
throw new IOException("Failed to extract CN from request DN: " + requestedDn, e)
}
if (domainAlternativeNames != null) {
for (String alternativeName : domainAlternativeNames) {
namesList.add(new GeneralName(IPAddress.isValid(alternativeName) ? GeneralName.iPAddress : GeneralName.dNSName, alternativeName))
}
}
GeneralNames subjectAltNames = new GeneralNames(namesList.toArray([] as GeneralName[]))
ExtensionsGenerator extGen = new ExtensionsGenerator()
extGen.addExtension(Extension.subjectAlternativeName, false, subjectAltNames)
return extGen.generate()
}
}

View File

@ -46,16 +46,6 @@
<artifactId>nifi-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-ssl</artifactId>

View File

@ -26,12 +26,12 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.URI;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
@ -64,7 +64,6 @@ import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.SiteToSiteTransportProtocol;
import org.apache.nifi.remote.protocol.socket.SocketClientProtocol;
import org.apache.nifi.security.util.CertificateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -72,6 +71,8 @@ public class EndpointConnectionPool implements PeerStatusProvider {
private static final Logger logger = LoggerFactory.getLogger(EndpointConnectionPool.class);
private static final SocketPeerIdentityProvider socketPeerIdentityProvider = new StandardSocketPeerIdentityProvider();
private final ConcurrentMap<PeerDescription, BlockingQueue<EndpointConnection>> connectionQueueMap = new ConcurrentHashMap<>();
private final Set<EndpointConnection> activeConnections = Collections.synchronizedSet(new HashSet<>());
@ -449,11 +450,12 @@ public class EndpointConnectionPool implements PeerStatusProvider {
socket.setSoTimeout(commsTimeout);
commsSession = new SocketCommunicationsSession(socket);
try {
final String dn = CertificateUtils.extractPeerDNFromSSLSocket(socket);
commsSession.setUserDn(dn);
} catch (final CertificateException ex) {
throw new IOException(ex);
final Optional<String> peerIdentity = socketPeerIdentityProvider.getPeerIdentity(socket);
if (peerIdentity.isPresent()) {
final String userDn = peerIdentity.get();
commsSession.setUserDn(userDn);
} else {
throw new IOException(String.format("Site-to-Site Peer [%s] Identity not found", socket.getRemoteSocketAddress()));
}
} else {

View File

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.client.socket;
import java.net.Socket;
import java.util.Optional;
/**
* Abstraction for reading identity information from socket connections
*/
public interface SocketPeerIdentityProvider {
/**
* Get Peer Identity from Socket
*
* @param socket Socket
* @return Peer Identity or empty when not found
*/
Optional<String> getPeerIdentity(Socket socket);
}

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.client.socket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import java.net.Socket;
import java.security.Principal;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Optional;
/**
* Standard implementation attempts to read X.509 certificates from an SSLSocket
*/
public class StandardSocketPeerIdentityProvider implements SocketPeerIdentityProvider {
private static final Logger logger = LoggerFactory.getLogger(StandardSocketPeerIdentityProvider.class);
@Override
public Optional<String> getPeerIdentity(final Socket socket) {
final Optional<String> peerIdentity;
if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket;
final SSLSession sslSession = sslSocket.getSession();
peerIdentity = getPeerIdentity(sslSession);
} else {
peerIdentity = Optional.empty();
}
return peerIdentity;
}
private Optional<String> getPeerIdentity(final SSLSession sslSession) {
String peerIdentity = null;
final String peerHost = sslSession.getPeerHost();
final int peerPort = sslSession.getPeerPort();
try {
final Certificate[] peerCertificates = sslSession.getPeerCertificates();
if (peerCertificates == null || peerCertificates.length == 0) {
logger.warn("Peer Identity not found: Peer Certificates not provided [{}:{}]", peerHost, peerPort);
} else {
final X509Certificate peerCertificate = (X509Certificate) peerCertificates[0];
final Principal subjectDistinguishedName = peerCertificate.getSubjectDN();
peerIdentity = subjectDistinguishedName.getName();
}
} catch (final SSLPeerUnverifiedException e) {
logger.warn("Peer Identity not found: Peer Unverified [{}:{}]", peerHost, peerPort);
logger.debug("TLS Protocol [{}] Peer Unverified [{}:{}]", sslSession.getProtocol(), peerHost, peerPort, e);
}
return Optional.ofNullable(peerIdentity);
}
}

View File

@ -47,7 +47,6 @@ import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collection;
@ -128,7 +127,6 @@ import org.apache.nifi.remote.protocol.ResponseCode;
import org.apache.nifi.remote.protocol.http.HttpHeaders;
import org.apache.nifi.remote.protocol.http.HttpProxy;
import org.apache.nifi.reporting.Severity;
import org.apache.nifi.security.util.CertificateUtils;
import org.apache.nifi.stream.io.StreamUtils;
import org.apache.nifi.web.api.dto.ControllerDTO;
import org.apache.nifi.web.api.dto.remote.PeerDTO;
@ -319,9 +317,9 @@ public class SiteToSiteRestApiClient implements Closeable {
}
try {
final X509Certificate cert = CertificateUtils.convertAbstractX509Certificate(certChain[0]);
final X509Certificate cert = (X509Certificate) certChain[0];
trustedPeerDn = cert.getSubjectDN().getName().trim();
} catch (final CertificateException e) {
} catch (final RuntimeException e) {
final String msg = "Could not extract subject DN from SSL session peer certificate";
logger.warn(msg);
eventReporter.reportEvent(Severity.WARNING, EVENT_CATEGORY, msg);

View File

@ -16,21 +16,16 @@
*/
package org.apache.nifi.remote.client
import org.apache.nifi.remote.PeerDescription
import org.apache.nifi.remote.PeerStatus
import org.apache.nifi.remote.TransferDirection
import org.apache.nifi.remote.protocol.SiteToSiteTransportProtocol
import org.apache.nifi.remote.util.PeerStatusCache
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.Security
import java.util.concurrent.ArrayBlockingQueue
import static org.junit.jupiter.api.Assertions.assertEquals
@ -52,16 +47,6 @@ class PeerSelectorTest {
private static mockPSP
private static mockPP
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@BeforeEach
void setUp() {
// Mock collaborators
@ -69,11 +54,6 @@ class PeerSelectorTest {
mockPP = mockPeerPersistence()
}
@AfterEach
void tearDown() {
}
private static String buildRemoteInstanceUris(List<String> nodes = DEFAULT_NODES) {
String remoteInstanceUris = "http://" + nodes.join(":8443/nifi-api,http://") + ":8443/nifi-api";
remoteInstanceUris
@ -206,7 +186,6 @@ class PeerSelectorTest {
new PeerStatusCache(peerStatuses, System.currentTimeMillis(), remoteInstanceUris, SiteToSiteTransportProtocol.HTTP)
},
save : { PeerStatusCache psc ->
logger.mock("Persisting PeerStatusCache: ${psc}")
}] as PeerPersistence
}
@ -985,8 +964,6 @@ class PeerSelectorTest {
bootstrapDescription
},
fetchRemotePeerStatuses : { PeerDescription pd ->
// Depending on the scenario, return given peer statuses
logger.mock("Scenario ${currentAttempt} fetchRemotePeerStatus for ${pd}")
switch (currentAttempt) {
case 1:
return [bootstrapStatus, node2Status] as Set<PeerStatus>

View File

@ -22,10 +22,8 @@ import org.apache.nifi.events.EventReporter;
import org.apache.nifi.remote.Peer;
import org.apache.nifi.remote.Transaction;
import org.apache.nifi.remote.TransferDirection;
import org.apache.nifi.remote.client.KeystoreType;
import org.apache.nifi.remote.client.SiteToSiteClient;
import org.apache.nifi.remote.codec.StandardFlowFileCodec;
import org.apache.nifi.remote.exception.HandshakeException;
import org.apache.nifi.remote.io.CompressionInputStream;
import org.apache.nifi.remote.io.CompressionOutputStream;
import org.apache.nifi.remote.protocol.DataPacket;
@ -34,8 +32,6 @@ import org.apache.nifi.remote.protocol.SiteToSiteTransportProtocol;
import org.apache.nifi.remote.protocol.http.HttpHeaders;
import org.apache.nifi.remote.protocol.http.HttpProxy;
import org.apache.nifi.remote.util.StandardDataPacket;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.stream.io.StreamUtils;
import org.apache.nifi.web.api.dto.ControllerDTO;
import org.apache.nifi.web.api.dto.PortDTO;
@ -45,16 +41,11 @@ import org.apache.nifi.web.api.entity.PeersEntity;
import org.apache.nifi.web.api.entity.TransactionResultEntity;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.SecureRequestCustomizer;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.SslConnectionFactory;
import org.eclipse.jetty.server.handler.ContextHandlerCollection;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHandler;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
@ -117,7 +108,6 @@ public class TestHttpClient {
private static Server server;
private static ServerConnector httpConnector;
private static ServerConnector sslConnector;
private static CountDownLatch testCaseFinished;
private static HttpProxyServer proxyServer;
@ -126,11 +116,8 @@ public class TestHttpClient {
private static Set<PortDTO> inputPorts;
private static Set<PortDTO> outputPorts;
private static Set<PeerDTO> peers;
private static Set<PeerDTO> peersSecure;
private static String serverChecksum;
private static TlsConfiguration tlsConfiguration;
private static final int INITIAL_TRANSACTIONS = 0;
private static final AtomicInteger outputExtendTransactions = new AtomicInteger(INITIAL_TRANSACTIONS);
@ -141,16 +128,13 @@ public class TestHttpClient {
public static class SiteInfoServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
final ControllerDTO controller = new ControllerDTO();
if (req.getLocalPort() == httpConnector.getLocalPort()) {
controller.setRemoteSiteHttpListeningPort(httpConnector.getLocalPort());
controller.setSiteToSiteSecure(false);
} else {
controller.setRemoteSiteHttpListeningPort(sslConnector.getLocalPort());
controller.setSiteToSiteSecure(true);
}
controller.setId("remote-controller-id");
@ -175,7 +159,7 @@ public class TestHttpClient {
public static class WrongSiteInfoServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
// This response simulates when a Site-to-Site is given a URL which has wrong path.
respondWithText(resp, "<p class=\"message-pane-content\">You may have mistyped...</p>", 200);
}
@ -184,16 +168,13 @@ public class TestHttpClient {
public static class PeersServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
final PeersEntity peersEntity = new PeersEntity();
if (req.getLocalPort() == httpConnector.getLocalPort()) {
assertNotNull(peers, "Test case should set <peers> depending on the test scenario.");
peersEntity.setPeers(peers);
} else {
assertNotNull(peersSecure, "Test case should set <peersSecure> depending on the test scenario.");
peersEntity.setPeers(peersSecure);
}
respondWithJson(resp, peersEntity);
@ -368,7 +349,7 @@ public class TestHttpClient {
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
final int reqProtocolVersion = getReqProtocolVersion(req);
@ -477,24 +458,10 @@ public class TestHttpClient {
final ServletHandler wrongPathServletHandler = new ServletHandler();
wrongPathContextHandler.insertHandler(wrongPathServletHandler);
final SslContextFactory sslContextFactory = new SslContextFactory.Server();
setTlsConfiguration();
sslContextFactory.setKeyStorePath(tlsConfiguration.getKeystorePath());
sslContextFactory.setKeyStorePassword(tlsConfiguration.getKeystorePassword());
sslContextFactory.setKeyStoreType(tlsConfiguration.getKeystoreType().getType());
sslContextFactory.setProtocol(TlsConfiguration.getHighestCurrentSupportedTlsProtocolVersion());
httpConnector = new ServerConnector(server);
final HttpConfiguration https = new HttpConfiguration();
https.addCustomizer(new SecureRequestCustomizer());
sslConnector = new ServerConnector(server,
new SslConnectionFactory(sslContextFactory, "http/1.1"),
new HttpConnectionFactory(https));
logger.info("SSL Connector: " + sslConnector.dump());
server.setConnectors(new Connector[] { httpConnector, sslConnector });
server.setConnectors(new Connector[] { httpConnector });
wrongPathServletHandler.addServletWithMapping(WrongSiteInfoServlet.class, "/site-to-site");
@ -528,8 +495,6 @@ public class TestHttpClient {
server.start();
logger.info("Starting server on port {} for HTTP, and {} for HTTPS", httpConnector.getLocalPort(), sslConnector.getLocalPort());
startProxyServer();
startProxyServerWithAuth();
}
@ -634,15 +599,6 @@ public class TestHttpClient {
peers = new HashSet<>();
peers.add(peer);
final PeerDTO peerSecure = new PeerDTO();
peerSecure.setHostname("localhost");
peerSecure.setPort(sslConnector.getLocalPort());
peerSecure.setFlowFileCount(10);
peerSecure.setSecure(true);
peersSecure = new HashSet<>();
peersSecure.add(peerSecure);
inputPorts = new HashSet<>();
final PortDTO runningInputPort = new PortDTO();
@ -711,18 +667,6 @@ public class TestHttpClient {
;
}
private SiteToSiteClient.Builder getDefaultBuilderHTTPS() {
return new SiteToSiteClient.Builder().transportProtocol(SiteToSiteTransportProtocol.HTTP)
.url("https://localhost:" + sslConnector.getLocalPort() + "/nifi")
.timeout(3, TimeUnit.MINUTES)
.keystoreFilename(tlsConfiguration.getKeystorePath())
.keystorePass(tlsConfiguration.getKeystorePassword())
.keystoreType(KeystoreType.valueOf(tlsConfiguration.getKeystoreType().getType()))
.truststoreFilename(tlsConfiguration.getTruststorePath())
.truststorePass(tlsConfiguration.getTruststorePassword())
.truststoreType(KeystoreType.valueOf(tlsConfiguration.getTruststoreType().getType()));
}
private static void consumeDataPacket(DataPacket packet) throws IOException {
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
StreamUtils.copy(packet.getData(), bos);
@ -893,31 +837,6 @@ public class TestHttpClient {
}
@Test
public void testSendAccessDeniedHTTPS() throws Exception {
try (
final SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("input-access-denied")
.build()
) {
assertThrows(HandshakeException.class, () -> client.createTransaction(TransferDirection.SEND));
}
}
@Test
public void testSendSuccessHTTPS() throws Exception {
try (
final SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("input-running")
.build()
) {
testSend(client);
}
}
private interface SendData {
void apply(final Transaction transaction) throws IOException;
}
@ -1013,47 +932,6 @@ public class TestHttpClient {
}
@Test
public void testSendLargeFileHTTPS() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("input-running")
.build()
) {
testSendLargeFile(client);
}
}
@Test
public void testSendLargeFileHTTPSWithProxy() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("input-running")
.httpProxy(new HttpProxy("localhost", proxyServer.getListenAddress().getPort(), null, null))
.build()
) {
testSendLargeFile(client);
}
}
@Test
public void testSendLargeFileHTTPSWithProxyAuth() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("input-running")
.httpProxy(new HttpProxy("localhost", proxyServerWithAuth.getListenAddress().getPort(), PROXY_USER, PROXY_PASSWORD))
.build()
) {
testSendLargeFile(client);
}
}
@Test
public void testSendSuccessCompressed() throws Exception {
@ -1264,44 +1142,6 @@ public class TestHttpClient {
}
}
@Test
public void testReceiveSuccessHTTPS() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("output-running")
.build()
) {
testReceive(client);
}
}
@Test
public void testReceiveSuccessHTTPSWithProxy() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("output-running")
.httpProxy(new HttpProxy("localhost", proxyServer.getListenAddress().getPort(), null, null))
.build()
) {
testReceive(client);
}
}
@Test
public void testReceiveSuccessHTTPSWithProxyAuth() throws Exception {
try (
SiteToSiteClient client = getDefaultBuilderHTTPS()
.portName("output-running")
.httpProxy(new HttpProxy("localhost", proxyServerWithAuth.getListenAddress().getPort(), PROXY_USER, PROXY_PASSWORD))
.build()
) {
testReceive(client);
}
}
@Test
public void testReceiveSuccessCompressed() throws Exception {
@ -1375,8 +1215,4 @@ public class TestHttpClient {
assertNotSame(INITIAL_TRANSACTIONS, outputExtendTransactions.get());
}
}
private static void setTlsConfiguration() {
tlsConfiguration = new TemporaryKeyStoreBuilder().trustStoreType(KeystoreType.JKS.name()).build();
}
}

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.client.socket;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.security.auth.x500.X500Principal;
import java.io.IOException;
import java.net.Socket;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class StandardSocketPeerIdentityProviderTest {
private static final String DISTINGUISHED_NAME = "CN=Common Name,OU=Organizational Unit,O=Organization";
@Mock
SSLSocket sslSocket;
@Mock
SSLSession sslSession;
@Mock
X509Certificate peerCertificate;
StandardSocketPeerIdentityProvider provider;
@BeforeEach
void setProvider() {
provider = new StandardSocketPeerIdentityProvider();
}
@Test
void testGetPeerIdentityStandardSocket() throws IOException {
try (Socket socket = new Socket()) {
final Optional<String> peerIdentity = provider.getPeerIdentity(socket);
assertFalse(peerIdentity.isPresent());
}
}
@Test
void testGetPeerIdentitySSLSocketPeerUnverifiedException() throws SSLPeerUnverifiedException {
when(sslSocket.getSession()).thenReturn(sslSession);
when(sslSession.getPeerCertificates()).thenThrow(new SSLPeerUnverifiedException(SSLPeerUnverifiedException.class.getSimpleName()));
final Optional<String> peerIdentity = provider.getPeerIdentity(sslSocket);
assertFalse(peerIdentity.isPresent());
}
@Test
void testGetPeerIdentitySSLSocketPeerCertificatesNotFound() throws SSLPeerUnverifiedException {
when(sslSocket.getSession()).thenReturn(sslSession);
when(sslSession.getPeerCertificates()).thenReturn(new Certificate[]{});
final Optional<String> peerIdentity = provider.getPeerIdentity(sslSocket);
assertFalse(peerIdentity.isPresent());
}
@Test
void testGetPeerIdentityFound() throws SSLPeerUnverifiedException {
when(sslSocket.getSession()).thenReturn(sslSession);
when(sslSession.getPeerCertificates()).thenReturn(new X509Certificate[]{peerCertificate});
final X500Principal subjectDistinguishedName = new X500Principal(DISTINGUISHED_NAME);
when(peerCertificate.getSubjectDN()).thenReturn(subjectDistinguishedName);
final Optional<String> peerIdentity = provider.getPeerIdentity(sslSocket);
assertTrue(peerIdentity.isPresent());
final String identity = peerIdentity.get();
assertEquals(DISTINGUISHED_NAME, identity);
}
}

View File

@ -39,6 +39,12 @@
<artifactId>nifi-api</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<!-- Included for StringUtils -->
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-property-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>jakarta.xml.bind</groupId>
<artifactId>jakarta.xml.bind-api</artifactId>

View File

@ -54,6 +54,7 @@
<module>nifi-repository-encryption</module>
<module>nifi-schema-utils</module>
<module>nifi-security-crypto-key</module>
<module>nifi-security-identity</module>
<module>nifi-security-kerberos-api</module>
<module>nifi-security-kerberos</module>
<module>nifi-security-kms</module>

View File

@ -136,6 +136,12 @@
<version>2.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.icegreen</groupId>
<artifactId>greenmail</artifactId>

View File

@ -32,6 +32,16 @@
<artifactId>nifi-api</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-mock</artifactId>

View File

@ -38,6 +38,10 @@
<artifactId>nifi-security-socket-ssl</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
</dependency>
<!-- Other modules using nifi-standard-record-utils are expected to have these APIs available, typically through a NAR dependency -->
<dependency>
<groupId>org.apache.nifi</groupId>

View File

@ -34,6 +34,10 @@
<artifactId>nifi-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
</dependency>
<dependency>
<groupId>org.glassfish</groupId>
<artifactId>javax.json</artifactId>

View File

@ -45,7 +45,6 @@
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-utils</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>

View File

@ -100,7 +100,8 @@
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>

View File

@ -77,6 +77,11 @@
<artifactId>nifi-security-kms</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-repository-encryption</artifactId>

View File

@ -38,6 +38,15 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-framework-core-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-mock</artifactId>

View File

@ -26,6 +26,10 @@ import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.security.GeneralSecurityException;
import java.security.Principal;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@ -33,7 +37,12 @@ import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import org.apache.nifi.groups.ProcessGroup;
import org.apache.nifi.remote.cluster.ClusterNodeInformation;
import org.apache.nifi.remote.cluster.NodeInformant;
@ -46,7 +55,6 @@ import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.RequestType;
import org.apache.nifi.remote.protocol.ServerProtocol;
import org.apache.nifi.security.util.CertificateUtils;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.util.NiFiProperties;
import org.slf4j.Logger;
@ -160,7 +168,8 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
try {
if (secure) {
LOG.trace("{} Connection is secure", this);
dn = CertificateUtils.extractPeerDNFromSSLSocket(socket);
final SSLSocket sslSocket = (SSLSocket) socket;
dn = getPeerIdentity(sslSocket);
commsSession = new SocketCommunicationsSession(socket);
commsSession.setUserDn(dn);
@ -174,7 +183,7 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
// TODO: Add SocketProtocolListener#handleTlsError logic here
String msg = String.format("RemoteSiteListener Unable to accept connection from %s due to %s", socket, e.getLocalizedMessage());
// Suppress repeated TLS errors
if (CertificateUtils.isTlsError(e)) {
if (isTlsError(e)) {
boolean printedAsWarning = handleTlsError(msg);
// TODO: Move into handleTlsError and refactor shared behavior
@ -320,6 +329,32 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
listenerThread.start();
}
private boolean isTlsError(final Throwable e) {
final boolean tlsError;
if (e instanceof SSLException || e instanceof GeneralSecurityException) {
tlsError = true;
} else if (e.getCause() == null) {
tlsError = false;
} else {
tlsError = isTlsError(e.getCause());
}
return tlsError;
}
private String getPeerIdentity(final SSLSocket sslSocket) throws SSLPeerUnverifiedException {
final SSLSession sslSession = sslSocket.getSession();
final Certificate[] peerCertificates = sslSession.getPeerCertificates();
if (peerCertificates == null || peerCertificates.length == 0) {
throw new SSLPeerUnverifiedException(String.format("Peer [%s] certificates not found", sslSocket.getRemoteSocketAddress()));
}
final X509Certificate peerCertificate = (X509Certificate) peerCertificates[0];
final Principal subjectDistinguishedName = peerCertificate.getSubjectDN();
return subjectDistinguishedName.getName();
}
private boolean handleTlsError(String msg) {
if (tlsErrorRecentlySeen()) {
LOG.debug(msg);
@ -331,7 +366,7 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
}
/**
* Returns {@code true} if any related exception (determined by {@link CertificateUtils#isTlsError(Throwable)}) has occurred within the last
* Returns {@code true} if any related exception has occurred within the last
* {@link #EXCEPTION_THRESHOLD_MILLIS} milliseconds. Does not evaluate the error locally,
* simply checks the last time the timestamp was updated.
*

View File

@ -1,127 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote
import org.apache.nifi.security.util.KeyStoreUtils
import org.apache.nifi.security.util.KeystoreType
import org.apache.nifi.security.util.SslContextFactory
import org.apache.nifi.security.util.StandardTlsConfiguration
import org.apache.nifi.security.util.TlsConfiguration
import org.apache.nifi.util.NiFiProperties
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLServerSocket
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class SocketRemoteSiteListenerTest {
private static final Logger logger = LoggerFactory.getLogger(SocketRemoteSiteListenerTest.class)
private static final String KEYSTORE_PATH = "src/test/resources/localhost-ks.jks"
private static final String KEYSTORE_PASSWORD = "OI7kMpWzzVNVx/JGhTL/0uO4+PWpGJ46uZ/pfepbkwI"
private static final KeystoreType KEYSTORE_TYPE = KeystoreType.JKS
private static final String TRUSTSTORE_PATH = "src/test/resources/localhost-ts.jks"
private static final String TRUSTSTORE_PASSWORD = "wAOR0nQJ2EXvOP0JZ2EaqA/n7W69ILS4sWAHghmIWCc"
private static final KeystoreType TRUSTSTORE_TYPE = KeystoreType.JKS
private static final String HOSTNAME = "localhost"
private static final int PORT = 0
// The nifi.properties in src/test/resources has 0.x properties and should be removed or updated
private static final Map<String, String> DEFAULT_PROPS = [
(NiFiProperties.SECURITY_KEYSTORE) : KEYSTORE_PATH,
(NiFiProperties.SECURITY_KEYSTORE_PASSWD) : KEYSTORE_PASSWORD,
(NiFiProperties.SECURITY_KEYSTORE_TYPE) : KEYSTORE_TYPE.getType(),
(NiFiProperties.SECURITY_TRUSTSTORE) : TRUSTSTORE_PATH,
(NiFiProperties.SECURITY_TRUSTSTORE_PASSWD): TRUSTSTORE_PASSWORD,
(NiFiProperties.SECURITY_TRUSTSTORE_TYPE) : TRUSTSTORE_TYPE.getType(),
(NiFiProperties.REMOTE_INPUT_HOST): HOSTNAME,
(NiFiProperties.REMOTE_INPUT_PORT): PORT as String,
"nifi.remote.input.secure": "true"
]
private NiFiProperties mockNiFiProperties = NiFiProperties.createBasicNiFiProperties("", DEFAULT_PROPS)
private static TlsConfiguration tlsConfiguration
private static SSLContext sslContext
private SocketRemoteSiteListener srsListener
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
tlsConfiguration = new StandardTlsConfiguration(KEYSTORE_PATH, KEYSTORE_PASSWORD, KEYSTORE_TYPE, TRUSTSTORE_PATH, TRUSTSTORE_PASSWORD, TRUSTSTORE_TYPE)
sslContext = SslContextFactory.createSslContext(tlsConfiguration)
}
@AfterEach
void tearDown() {
if (srsListener) {
srsListener.stop()
}
}
/**
* Asserts that the protocol versions in the parameters object are correct. In recent versions of Java, this enforces order as well, but in older versions, it just enforces presence.
*
* @param enabledProtocols the actual protocols, either in {@code String[]} or {@code Collection<String>} form
* @param expectedProtocols the specific protocol versions to be present (ordered as desired)
*/
static void assertProtocolVersions(def enabledProtocols, def expectedProtocols) {
if (TlsConfiguration.getJavaVersion() > 8) {
assertArrayEquals(expectedProtocols as String[], enabledProtocols)
} else {
assertEquals(expectedProtocols as Set, enabledProtocols as Set)
}
}
@Test
void testShouldCreateSecureServer() {
// Arrange
logger.info("Creating SSL Context from TLS Configuration: ${tlsConfiguration}")
SSLContext sslContext = SslContextFactory.createSslContext(tlsConfiguration)
logger.info("Created SSL Context: ${KeyStoreUtils.sslContextToString(sslContext)}")
srsListener = new SocketRemoteSiteListener(PORT, sslContext, mockNiFiProperties)
// Act
srsListener.start()
// Assert
// serverSocket isn't instance field like CLBS so have to use private method invocation to verify
SSLServerSocket sslServerSocket = srsListener.createServerSocket() as SSLServerSocket
logger.info("Created SSL server socket: ${KeyStoreUtils.sslServerSocketToString(sslServerSocket)}" as String)
assertProtocolVersions(sslServerSocket.enabledProtocols, TlsConfiguration.getCurrentSupportedTlsProtocolVersions())
assertTrue(sslServerSocket.needClientAuth)
}
}

View File

@ -187,6 +187,12 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-xml-processing</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-web-security</artifactId>

View File

@ -101,6 +101,11 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-framework-core</artifactId>

View File

@ -32,7 +32,6 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>

View File

@ -41,6 +41,11 @@
<artifactId>nifi-security-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-identity</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-properties</artifactId>

View File

@ -59,6 +59,11 @@
<artifactId>nifi-event-transport</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-mock</artifactId>

View File

@ -57,6 +57,12 @@
<artifactId>nifi-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-security-utils</artifactId>
<version>2.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-mock</artifactId>