diff --git a/nifi-commons/nifi-security-utils/pom.xml b/nifi-commons/nifi-security-utils/pom.xml index 76d3c9aa88..1381e82d9d 100644 --- a/nifi-commons/nifi-security-utils/pom.xml +++ b/nifi-commons/nifi-security-utils/pom.xml @@ -30,6 +30,16 @@ org.apache.commons commons-lang3 + + org.bouncycastle + bcprov-jdk15on + test + + + org.bouncycastle + bcpkix-jdk15on + test + diff --git a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java index cf9a538c27..b3321f772e 100644 --- a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java +++ b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java @@ -30,6 +30,9 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import javax.naming.InvalidNameException; +import javax.naming.ldap.LdapName; +import javax.naming.ldap.Rdn; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; import org.apache.commons.lang3.StringUtils; @@ -39,6 +42,25 @@ import org.slf4j.LoggerFactory; public final class CertificateUtils { private static final Logger logger = LoggerFactory.getLogger(CertificateUtils.class); + private static final String PEER_NOT_AUTHENTICATED_MSG = "peer not authenticated"; + + public enum ClientAuth { + NONE(0, "none"), + WANT(1, "want"), + NEED(2, "need"); + + private int value; + private String description; + + ClientAuth(int value, String description) { + this.value = value; + this.description = description; + } + + public String toString() { + return "Client Auth: " + this.description + " (" + this.value + ")"; + } + } /** * Returns true if the given keystore can be loaded using the given keystore type and password. Returns false otherwise. @@ -148,20 +170,43 @@ public final class CertificateUtils { String dn = null; if (socket instanceof SSLSocket) { final SSLSocket sslSocket = (SSLSocket) socket; - try { - final Certificate[] certChains = sslSocket.getSession().getPeerCertificates(); - if (certChains != null && certChains.length > 0) { - X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]); - dn = x509Certificate.getSubjectDN().getName().trim(); + + /** The clientAuth value can be "need", "want", or "none" + * A client must send client certificates for need, should for want, and will not for none. + * This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none. + */ + + ClientAuth clientAuth = getClientAuthStatus(sslSocket); + logger.debug("SSL Socket client auth status: {}", clientAuth); + + if (clientAuth != ClientAuth.NONE) { + try { + final Certificate[] certChains = sslSocket.getSession().getPeerCertificates(); + if (certChains != null && certChains.length > 0) { + X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]); + dn = x509Certificate.getSubjectDN().getName().trim(); + } + } catch (SSLPeerUnverifiedException e) { + if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) { + logger.error("The incoming request did not contain client certificates and thus the DN cannot" + + " be extracted. Check that the other endpoint is providing a complete client certificate chain"); + } + if (clientAuth == ClientAuth.WANT) { + logger.warn("Suppressing missing client certificate exception because client auth is set to 'want'"); + return dn; + } + throw new CertificateException(e); } - } catch (SSLPeerUnverifiedException e) { - throw new CertificateException(e); } } return dn; } + private static ClientAuth getClientAuthStatus(SSLSocket sslSocket) { + return sslSocket.getNeedClientAuth() ? ClientAuth.NEED : sslSocket.getWantClientAuth() ? ClientAuth.WANT : ClientAuth.NONE; + } + /** * Accepts a legacy {@link javax.security.cert.X509Certificate} and returns an {@link X509Certificate}. The {@code javax.*} package certificate classes are for legacy compatibility and should * not be used for new development. @@ -213,6 +258,45 @@ public final class CertificateUtils { } } + /** + * Returns true if the two provided DNs are equivalent, regardless of the order of the elements. Returns false if one or both are invalid DNs. + * + * Example: + * + * CN=test1, O=testOrg, C=US compared to CN=test1, O=testOrg, C=US -> true + * CN=test1, O=testOrg, C=US compared to O=testOrg, CN=test1, C=US -> true + * CN=test1, O=testOrg, C=US compared to CN=test2, O=testOrg, C=US -> false + * CN=test1, O=testOrg, C=US compared to O=testOrg, CN=test2, C=US -> false + * CN=test1, O=testOrg, C=US compared to -> false + * compared to -> true + * + * @param dn1 the first DN to compare + * @param dn2 the second DN to compare + * @return true if the DNs are equivalent, false otherwise + */ + public static boolean compareDNs(String dn1, String dn2) { + if (dn1 == null) { + dn1 = ""; + } + + if (dn2 == null) { + dn2 = ""; + } + + if (StringUtils.isEmpty(dn1) || StringUtils.isEmpty(dn2)) { + return dn1.equals(dn2); + } + try { + List rdn1 = new LdapName(dn1).getRdns(); + List rdn2 = new LdapName(dn2).getRdns(); + + return rdn1.size() == rdn2.size() && rdn1.containsAll(rdn2); + } catch (InvalidNameException e) { + logger.warn("Cannot compare DNs: {} and {} because one or both is not a valid DN", dn1, dn2); + return false; + } + } + private CertificateUtils() { } } diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy similarity index 63% rename from nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy rename to nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy index 2be2e16373..2d00a256a2 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy +++ b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy @@ -38,6 +38,9 @@ import org.junit.runners.JUnit4 import org.slf4j.Logger import org.slf4j.LoggerFactory +import javax.net.ssl.SSLPeerUnverifiedException +import javax.net.ssl.SSLSession +import javax.net.ssl.SSLSocket import java.security.InvalidKeyException import java.security.KeyPair import java.security.KeyPairGenerator @@ -272,4 +275,179 @@ class CertificateUtilsTest extends GroovyTestCase { assert convertedCertificate instanceof X509Certificate assert convertedCertificate == EXPECTED_NEW_CERTIFICATE } + + @Test + void testShouldDetermineClientAuthStatusFromSocket() { + // Arrange + SSLSocket needSocket = [getNeedClientAuth: { -> true }] as SSLSocket + SSLSocket wantSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> true }] as SSLSocket + SSLSocket noneSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> false }] as SSLSocket + + // Act + CertificateUtils.ClientAuth needClientAuthStatus = CertificateUtils.getClientAuthStatus(needSocket) + logger.info("Client auth (needSocket): ${needClientAuthStatus}") + CertificateUtils.ClientAuth wantClientAuthStatus = CertificateUtils.getClientAuthStatus(wantSocket) + logger.info("Client auth (wantSocket): ${wantClientAuthStatus}") + CertificateUtils.ClientAuth noneClientAuthStatus = CertificateUtils.getClientAuthStatus(noneSocket) + logger.info("Client auth (noneSocket): ${noneClientAuthStatus}") + + // Assert + assert needClientAuthStatus == CertificateUtils.ClientAuth.NEED + assert wantClientAuthStatus == CertificateUtils.ClientAuth.WANT + assert noneClientAuthStatus == CertificateUtils.ClientAuth.NONE + } + + @Test + void testShouldNotExtractClientCertificatesFromSSLSocketWithClientAuthNone() { + // Arrange + SSLSocket mockSocket = [ + getNeedClientAuth: { -> false }, + getWantClientAuth: { -> false } + ] as SSLSocket + + // Act + String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) + logger.info("Extracted client DN: ${clientDN}") + + // Assert + assert !clientDN + } + + @Test + void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthWant() { + // Arrange + final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US" + Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN) + logger.info("Expected DN: ${EXPECTED_DN}") + logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}") + + SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession + + SSLSocket mockSocket = [ + getNeedClientAuth: { -> false }, + getWantClientAuth: { -> true }, + getSession : { -> mockSession } + ] as SSLSocket + + // Act + String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) + logger.info("Extracted client DN: ${clientDN}") + + // Assert + assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN) + } + + @Test + void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthWant() { + // Arrange + SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession + + SSLSocket mockSocket = [ + getNeedClientAuth: { -> false }, + getWantClientAuth: { -> true }, + getSession : { -> mockSession } + ] as SSLSocket + + // Act + String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) + logger.info("Extracted client DN: ${clientDN}") + + // Assert + assert CertificateUtils.compareDNs(clientDN, null) + } + + + @Test + void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthNeed() { + // Arrange + final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US" + Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN) + logger.info("Expected DN: ${EXPECTED_DN}") + logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}") + + SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession + + SSLSocket mockSocket = [ + getNeedClientAuth: { -> true }, + getWantClientAuth: { -> false }, + getSession : { -> mockSession } + ] as SSLSocket + + // Act + String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) + logger.info("Extracted client DN: ${clientDN}") + + // Assert + assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN) + } + + @Test + void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthNeed() { + // Arrange + SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession + + SSLSocket mockSocket = [ + getNeedClientAuth: { -> true }, + getWantClientAuth: { -> false }, + getSession : { -> mockSession } + ] as SSLSocket + + // Act + def msg = shouldFail(CertificateException) { + String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) + logger.info("Extracted client DN: ${clientDN}") + } + + // Assert + assert msg =~ "peer not authenticated" + } + + @Test + void testShouldCompareDNs() { + // Arrange + final String DN_1_ORDERED = "CN=test1.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US" + logger.info("DN 1 Ordered : ${DN_1_ORDERED}") + final String DN_1_REVERSED = DN_1_ORDERED.split(", ").reverse().join(", ") + logger.info("DN 1 Reversed: ${DN_1_REVERSED}") + + final String DN_2_ORDERED = "CN=test2.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US" + logger.info("DN 2 Ordered : ${DN_2_ORDERED}") + final String DN_2_REVERSED = DN_2_ORDERED.split(", ").reverse().join(", ") + logger.info("DN 2 Reversed: ${DN_2_REVERSED}") + + // Act + + // True + boolean dn1MatchesSelf = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_ORDERED) + logger.matches("DN 1, DN 1: ${dn1MatchesSelf}") + + boolean dn1MatchesReversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_REVERSED) + logger.matches("DN 1, DN 1 (R): ${dn1MatchesReversed}") + + boolean emptyMatchesEmpty = CertificateUtils.compareDNs("", "") + logger.matches("empty, empty: ${emptyMatchesEmpty}") + + boolean nullMatchesNull = CertificateUtils.compareDNs(null, null) + logger.matches("null, null: ${nullMatchesNull}") + + // False + boolean dn1MatchesDn2 = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_ORDERED) + logger.matches("DN 1, DN 2: ${dn1MatchesDn2}") + + boolean dn1MatchesDn2Reversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_REVERSED) + logger.matches("DN 1, DN 2 (R): ${dn1MatchesDn2Reversed}") + + boolean dn1MatchesEmpty = CertificateUtils.compareDNs(DN_1_ORDERED, "") + logger.matches("DN 1, empty: ${dn1MatchesEmpty}") + + // Assert + assert dn1MatchesSelf + assert dn1MatchesReversed + assert emptyMatchesEmpty + assert nullMatchesNull + + assert !dn1MatchesDn2 + assert !dn1MatchesDn2Reversed + assert !dn1MatchesEmpty + } }