diff --git a/nifi-commons/nifi-security-utils/pom.xml b/nifi-commons/nifi-security-utils/pom.xml
index 76d3c9aa88..1381e82d9d 100644
--- a/nifi-commons/nifi-security-utils/pom.xml
+++ b/nifi-commons/nifi-security-utils/pom.xml
@@ -30,6 +30,16 @@
org.apache.commons
commons-lang3
+
+ org.bouncycastle
+ bcprov-jdk15on
+ test
+
+
+ org.bouncycastle
+ bcpkix-jdk15on
+ test
+
diff --git a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java
index cf9a538c27..b3321f772e 100644
--- a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java
+++ b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java
@@ -30,6 +30,9 @@ import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
+import javax.naming.InvalidNameException;
+import javax.naming.ldap.LdapName;
+import javax.naming.ldap.Rdn;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import org.apache.commons.lang3.StringUtils;
@@ -39,6 +42,25 @@ import org.slf4j.LoggerFactory;
public final class CertificateUtils {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtils.class);
+ private static final String PEER_NOT_AUTHENTICATED_MSG = "peer not authenticated";
+
+ public enum ClientAuth {
+ NONE(0, "none"),
+ WANT(1, "want"),
+ NEED(2, "need");
+
+ private int value;
+ private String description;
+
+ ClientAuth(int value, String description) {
+ this.value = value;
+ this.description = description;
+ }
+
+ public String toString() {
+ return "Client Auth: " + this.description + " (" + this.value + ")";
+ }
+ }
/**
* Returns true if the given keystore can be loaded using the given keystore type and password. Returns false otherwise.
@@ -148,20 +170,43 @@ public final class CertificateUtils {
String dn = null;
if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket;
- try {
- final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
- if (certChains != null && certChains.length > 0) {
- X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
- dn = x509Certificate.getSubjectDN().getName().trim();
+
+ /** The clientAuth value can be "need", "want", or "none"
+ * A client must send client certificates for need, should for want, and will not for none.
+ * This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none.
+ */
+
+ ClientAuth clientAuth = getClientAuthStatus(sslSocket);
+ logger.debug("SSL Socket client auth status: {}", clientAuth);
+
+ if (clientAuth != ClientAuth.NONE) {
+ try {
+ final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
+ if (certChains != null && certChains.length > 0) {
+ X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
+ dn = x509Certificate.getSubjectDN().getName().trim();
+ }
+ } catch (SSLPeerUnverifiedException e) {
+ if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
+ logger.error("The incoming request did not contain client certificates and thus the DN cannot" +
+ " be extracted. Check that the other endpoint is providing a complete client certificate chain");
+ }
+ if (clientAuth == ClientAuth.WANT) {
+ logger.warn("Suppressing missing client certificate exception because client auth is set to 'want'");
+ return dn;
+ }
+ throw new CertificateException(e);
}
- } catch (SSLPeerUnverifiedException e) {
- throw new CertificateException(e);
}
}
return dn;
}
+ private static ClientAuth getClientAuthStatus(SSLSocket sslSocket) {
+ return sslSocket.getNeedClientAuth() ? ClientAuth.NEED : sslSocket.getWantClientAuth() ? ClientAuth.WANT : ClientAuth.NONE;
+ }
+
/**
* Accepts a legacy {@link javax.security.cert.X509Certificate} and returns an {@link X509Certificate}. The {@code javax.*} package certificate classes are for legacy compatibility and should
* not be used for new development.
@@ -213,6 +258,45 @@ public final class CertificateUtils {
}
}
+ /**
+ * Returns true if the two provided DNs are equivalent, regardless of the order of the elements. Returns false if one or both are invalid DNs.
+ *
+ * Example:
+ *
+ * CN=test1, O=testOrg, C=US compared to CN=test1, O=testOrg, C=US -> true
+ * CN=test1, O=testOrg, C=US compared to O=testOrg, CN=test1, C=US -> true
+ * CN=test1, O=testOrg, C=US compared to CN=test2, O=testOrg, C=US -> false
+ * CN=test1, O=testOrg, C=US compared to O=testOrg, CN=test2, C=US -> false
+ * CN=test1, O=testOrg, C=US compared to -> false
+ * compared to -> true
+ *
+ * @param dn1 the first DN to compare
+ * @param dn2 the second DN to compare
+ * @return true if the DNs are equivalent, false otherwise
+ */
+ public static boolean compareDNs(String dn1, String dn2) {
+ if (dn1 == null) {
+ dn1 = "";
+ }
+
+ if (dn2 == null) {
+ dn2 = "";
+ }
+
+ if (StringUtils.isEmpty(dn1) || StringUtils.isEmpty(dn2)) {
+ return dn1.equals(dn2);
+ }
+ try {
+ List rdn1 = new LdapName(dn1).getRdns();
+ List rdn2 = new LdapName(dn2).getRdns();
+
+ return rdn1.size() == rdn2.size() && rdn1.containsAll(rdn2);
+ } catch (InvalidNameException e) {
+ logger.warn("Cannot compare DNs: {} and {} because one or both is not a valid DN", dn1, dn2);
+ return false;
+ }
+ }
+
private CertificateUtils() {
}
}
diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy
similarity index 63%
rename from nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy
rename to nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy
index 2be2e16373..2d00a256a2 100644
--- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy
+++ b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy
@@ -38,6 +38,9 @@ import org.junit.runners.JUnit4
import org.slf4j.Logger
import org.slf4j.LoggerFactory
+import javax.net.ssl.SSLPeerUnverifiedException
+import javax.net.ssl.SSLSession
+import javax.net.ssl.SSLSocket
import java.security.InvalidKeyException
import java.security.KeyPair
import java.security.KeyPairGenerator
@@ -272,4 +275,179 @@ class CertificateUtilsTest extends GroovyTestCase {
assert convertedCertificate instanceof X509Certificate
assert convertedCertificate == EXPECTED_NEW_CERTIFICATE
}
+
+ @Test
+ void testShouldDetermineClientAuthStatusFromSocket() {
+ // Arrange
+ SSLSocket needSocket = [getNeedClientAuth: { -> true }] as SSLSocket
+ SSLSocket wantSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> true }] as SSLSocket
+ SSLSocket noneSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> false }] as SSLSocket
+
+ // Act
+ CertificateUtils.ClientAuth needClientAuthStatus = CertificateUtils.getClientAuthStatus(needSocket)
+ logger.info("Client auth (needSocket): ${needClientAuthStatus}")
+ CertificateUtils.ClientAuth wantClientAuthStatus = CertificateUtils.getClientAuthStatus(wantSocket)
+ logger.info("Client auth (wantSocket): ${wantClientAuthStatus}")
+ CertificateUtils.ClientAuth noneClientAuthStatus = CertificateUtils.getClientAuthStatus(noneSocket)
+ logger.info("Client auth (noneSocket): ${noneClientAuthStatus}")
+
+ // Assert
+ assert needClientAuthStatus == CertificateUtils.ClientAuth.NEED
+ assert wantClientAuthStatus == CertificateUtils.ClientAuth.WANT
+ assert noneClientAuthStatus == CertificateUtils.ClientAuth.NONE
+ }
+
+ @Test
+ void testShouldNotExtractClientCertificatesFromSSLSocketWithClientAuthNone() {
+ // Arrange
+ SSLSocket mockSocket = [
+ getNeedClientAuth: { -> false },
+ getWantClientAuth: { -> false }
+ ] as SSLSocket
+
+ // Act
+ String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
+ logger.info("Extracted client DN: ${clientDN}")
+
+ // Assert
+ assert !clientDN
+ }
+
+ @Test
+ void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthWant() {
+ // Arrange
+ final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
+ Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
+ logger.info("Expected DN: ${EXPECTED_DN}")
+ logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
+
+ SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
+
+ SSLSocket mockSocket = [
+ getNeedClientAuth: { -> false },
+ getWantClientAuth: { -> true },
+ getSession : { -> mockSession }
+ ] as SSLSocket
+
+ // Act
+ String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
+ logger.info("Extracted client DN: ${clientDN}")
+
+ // Assert
+ assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN)
+ }
+
+ @Test
+ void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthWant() {
+ // Arrange
+ SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession
+
+ SSLSocket mockSocket = [
+ getNeedClientAuth: { -> false },
+ getWantClientAuth: { -> true },
+ getSession : { -> mockSession }
+ ] as SSLSocket
+
+ // Act
+ String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
+ logger.info("Extracted client DN: ${clientDN}")
+
+ // Assert
+ assert CertificateUtils.compareDNs(clientDN, null)
+ }
+
+
+ @Test
+ void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthNeed() {
+ // Arrange
+ final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
+ Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
+ logger.info("Expected DN: ${EXPECTED_DN}")
+ logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
+
+ SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
+
+ SSLSocket mockSocket = [
+ getNeedClientAuth: { -> true },
+ getWantClientAuth: { -> false },
+ getSession : { -> mockSession }
+ ] as SSLSocket
+
+ // Act
+ String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
+ logger.info("Extracted client DN: ${clientDN}")
+
+ // Assert
+ assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN)
+ }
+
+ @Test
+ void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthNeed() {
+ // Arrange
+ SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession
+
+ SSLSocket mockSocket = [
+ getNeedClientAuth: { -> true },
+ getWantClientAuth: { -> false },
+ getSession : { -> mockSession }
+ ] as SSLSocket
+
+ // Act
+ def msg = shouldFail(CertificateException) {
+ String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
+ logger.info("Extracted client DN: ${clientDN}")
+ }
+
+ // Assert
+ assert msg =~ "peer not authenticated"
+ }
+
+ @Test
+ void testShouldCompareDNs() {
+ // Arrange
+ final String DN_1_ORDERED = "CN=test1.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
+ logger.info("DN 1 Ordered : ${DN_1_ORDERED}")
+ final String DN_1_REVERSED = DN_1_ORDERED.split(", ").reverse().join(", ")
+ logger.info("DN 1 Reversed: ${DN_1_REVERSED}")
+
+ final String DN_2_ORDERED = "CN=test2.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
+ logger.info("DN 2 Ordered : ${DN_2_ORDERED}")
+ final String DN_2_REVERSED = DN_2_ORDERED.split(", ").reverse().join(", ")
+ logger.info("DN 2 Reversed: ${DN_2_REVERSED}")
+
+ // Act
+
+ // True
+ boolean dn1MatchesSelf = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_ORDERED)
+ logger.matches("DN 1, DN 1: ${dn1MatchesSelf}")
+
+ boolean dn1MatchesReversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_REVERSED)
+ logger.matches("DN 1, DN 1 (R): ${dn1MatchesReversed}")
+
+ boolean emptyMatchesEmpty = CertificateUtils.compareDNs("", "")
+ logger.matches("empty, empty: ${emptyMatchesEmpty}")
+
+ boolean nullMatchesNull = CertificateUtils.compareDNs(null, null)
+ logger.matches("null, null: ${nullMatchesNull}")
+
+ // False
+ boolean dn1MatchesDn2 = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_ORDERED)
+ logger.matches("DN 1, DN 2: ${dn1MatchesDn2}")
+
+ boolean dn1MatchesDn2Reversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_REVERSED)
+ logger.matches("DN 1, DN 2 (R): ${dn1MatchesDn2Reversed}")
+
+ boolean dn1MatchesEmpty = CertificateUtils.compareDNs(DN_1_ORDERED, "")
+ logger.matches("DN 1, empty: ${dn1MatchesEmpty}")
+
+ // Assert
+ assert dn1MatchesSelf
+ assert dn1MatchesReversed
+ assert emptyMatchesEmpty
+ assert nullMatchesNull
+
+ assert !dn1MatchesDn2
+ assert !dn1MatchesDn2Reversed
+ assert !dn1MatchesEmpty
+ }
}