NIFI-2186 Refactored CertificateUtils to separate logic for DN extraction from server/client sockets. Added logic to detect server/client mode encapsulated in exposed method.

Added unit tests for DN extraction.
Corrected typo in Javadoc.
Switched server/client socket logic for certificate extraction -- when the local socket is in client/server mode, the peer is necessarily the inverse.
Fixed unit tests.
Moved lazy-loading authentication access out of isDebugEnabled() control branch.
This closes #622
This commit is contained in:
Andy LoPresto 2016-07-04 21:05:58 -07:00 committed by Matt Gilman
parent 039fd70ded
commit 4b9df7d1e2
5 changed files with 149 additions and 17 deletions

View File

@ -166,11 +166,51 @@ public final class CertificateUtils {
return result; return result;
} }
public static String extractClientDNFromSSLSocket(Socket socket) throws CertificateException { /**
* Returns the DN extracted from the peer certificate (the server DN if run on the client; the client DN (if available) if run on the server).
*
* If the client auth setting is WANT or NONE and a client certificate is not present, this method will return {@code null}.
* If the client auth is NEED, it will throw a {@link CertificateException}.
*
* @param socket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
public static String extractPeerDNFromSSLSocket(Socket socket) throws CertificateException {
String dn = null; String dn = null;
if (socket instanceof SSLSocket) { if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket; final SSLSocket sslSocket = (SSLSocket) socket;
boolean clientMode = sslSocket.getUseClientMode();
logger.debug("SSL Socket in {} mode", clientMode ? "client" : "server");
ClientAuth clientAuth = getClientAuthStatus(sslSocket);
logger.debug("SSL Socket client auth status: {}", clientAuth);
if (clientMode) {
logger.debug("This socket is in client mode, so attempting to extract certificate from remote 'server' socket");
dn = extractPeerDNFromServerSSLSocket(sslSocket);
} else {
logger.debug("This socket is in server mode, so attempting to extract certificate from remote 'client' socket");
dn = extractPeerDNFromClientSSLSocket(sslSocket);
}
}
return dn;
}
/**
* Returns the DN extracted from the client certificate.
*
* If the client auth setting is WANT or NONE and a certificate is not present (and {@code respectClientAuth} is {@code true}), this method will return {@code null}.
* If the client auth is NEED, it will throw a {@link CertificateException}.
*
* @param sslSocket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
private static String extractPeerDNFromClientSSLSocket(SSLSocket sslSocket) throws CertificateException {
String dn = null;
/** The clientAuth value can be "need", "want", or "none" /** The clientAuth value can be "need", "want", or "none"
* A client must send client certificates for need, should for want, and will not for none. * A client must send client certificates for need, should for want, and will not for none.
* This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none. * This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none.
@ -185,6 +225,7 @@ public final class CertificateUtils {
if (certChains != null && certChains.length > 0) { if (certChains != null && certChains.length > 0) {
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]); X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
dn = x509Certificate.getSubjectDN().getName().trim(); dn = x509Certificate.getSubjectDN().getName().trim();
logger.debug("Extracted DN={} from client certificate", dn);
} }
} catch (SSLPeerUnverifiedException e) { } catch (SSLPeerUnverifiedException e) {
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) { if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
@ -198,8 +239,35 @@ public final class CertificateUtils {
throw new CertificateException(e); throw new CertificateException(e);
} }
} }
return dn;
} }
/**
* Returns the DN extracted from the server certificate.
*
* @param socket the SSL Socket
* @return the extracted DN
* @throws CertificateException if there is a problem parsing the certificate
*/
private static String extractPeerDNFromServerSSLSocket(Socket socket) throws CertificateException {
String dn = null;
if (socket instanceof SSLSocket) {
final SSLSocket sslSocket = (SSLSocket) socket;
try {
final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
if (certChains != null && certChains.length > 0) {
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
dn = x509Certificate.getSubjectDN().getName().trim();
logger.debug("Extracted DN={} from server certificate", dn);
}
} catch (SSLPeerUnverifiedException e) {
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
logger.error("The server did not present a certificate and thus the DN cannot" +
" be extracted. Check that the other endpoint is providing a complete certificate chain");
}
throw new CertificateException(e);
}
}
return dn; return dn;
} }

View File

@ -297,16 +297,67 @@ class CertificateUtilsTest extends GroovyTestCase {
assert noneClientAuthStatus == CertificateUtils.ClientAuth.NONE assert noneClientAuthStatus == CertificateUtils.ClientAuth.NONE
} }
@Test @Test
void testShouldNotExtractClientCertificatesFromSSLSocketWithClientAuthNone() { void testShouldExtractClientCertificatesFromSSLServerSocketWithAnyClientAuth() {
final String EXPECTED_DN = "CN=ncm.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in client mode, so the peer ("target") is a server
// Create mock sockets for each possible value of ClientAuth
SSLSocket mockNoneSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockNeedSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockWantSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
def resolvedServerDNs = [mockNeedSocket, mockWantSocket, mockNoneSocket].collect { SSLSocket mockSocket ->
logger.info("Running test with socket ClientAuth setting: ${CertificateUtils.getClientAuthStatus(mockSocket)}")
String serverDN = CertificateUtils.extractPeerDNFromSSLSocket(mockNoneSocket)
logger.info("Extracted server DN: ${serverDN}")
serverDN
}
// Assert
assert resolvedServerDNs.every { String serverDN ->
CertificateUtils.compareDNs(serverDN, EXPECTED_DN)
}
}
@Test
void testShouldNotExtractClientCertificatesFromSSLClientSocketWithClientAuthNone() {
// Arrange // Arrange
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [ SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false }, getNeedClientAuth: { -> false },
getWantClientAuth: { -> false } getWantClientAuth: { -> false }
] as SSLSocket ] as SSLSocket
// Act // Act
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
@ -314,7 +365,7 @@ class CertificateUtilsTest extends GroovyTestCase {
} }
@Test @Test
void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthWant() { void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange // Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US" final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN) Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
@ -323,14 +374,16 @@ class CertificateUtilsTest extends GroovyTestCase {
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [ SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false }, getNeedClientAuth: { -> false },
getWantClientAuth: { -> true }, getWantClientAuth: { -> true },
getSession : { -> mockSession } getSession : { -> mockSession }
] as SSLSocket ] as SSLSocket
// Act // Act
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
@ -338,18 +391,22 @@ class CertificateUtilsTest extends GroovyTestCase {
} }
@Test @Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthWant() { void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange // Arrange
SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [ SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false }, getNeedClientAuth: { -> false },
getWantClientAuth: { -> true }, getWantClientAuth: { -> true },
getSession : { -> mockSession } getSession : { -> mockSession }
] as SSLSocket ] as SSLSocket
// Act // Act
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
@ -358,7 +415,7 @@ class CertificateUtilsTest extends GroovyTestCase {
@Test @Test
void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthNeed() { void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange // Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US" final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN) Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
@ -367,14 +424,16 @@ class CertificateUtilsTest extends GroovyTestCase {
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [ SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true }, getNeedClientAuth: { -> true },
getWantClientAuth: { -> false }, getWantClientAuth: { -> false },
getSession : { -> mockSession } getSession : { -> mockSession }
] as SSLSocket ] as SSLSocket
// Act // Act
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
@ -382,11 +441,15 @@ class CertificateUtilsTest extends GroovyTestCase {
} }
@Test @Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthNeed() { void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange // Arrange
SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [ SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true }, getNeedClientAuth: { -> true },
getWantClientAuth: { -> false }, getWantClientAuth: { -> false },
getSession : { -> mockSession } getSession : { -> mockSession }
@ -394,7 +457,7 @@ class CertificateUtilsTest extends GroovyTestCase {
// Act // Act
def msg = shouldFail(CertificateException) { def msg = shouldFail(CertificateException) {
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket) String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
} }

View File

@ -95,7 +95,7 @@ public abstract class AbstractNodeProtocolSender implements NodeProtocolSender {
private String getCoordinatorDN(Socket socket) { private String getCoordinatorDN(Socket socket) {
try { try {
return CertificateUtils.extractClientDNFromSSLSocket(socket); return CertificateUtils.extractPeerDNFromSSLSocket(socket);
} catch (CertificateException e) { } catch (CertificateException e) {
throw new ProtocolException(e); throw new ProtocolException(e);
} }

View File

@ -187,7 +187,7 @@ public class SocketProtocolListener extends SocketListener implements ProtocolLi
private String getRequestorDN(Socket socket) { private String getRequestorDN(Socket socket) {
try { try {
return CertificateUtils.extractClientDNFromSSLSocket(socket); return CertificateUtils.extractPeerDNFromSSLSocket(socket);
} catch (CertificateException e) { } catch (CertificateException e) {
throw new ProtocolException(e); throw new ProtocolException(e);
} }

View File

@ -48,8 +48,9 @@ public abstract class NiFiAuthenticationFilter extends GenericFilterBean {
@Override @Override
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException { public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (log.isDebugEnabled()) { if (log.isDebugEnabled()) {
log.debug("Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication()); log.debug("Checking secure context token: " + authentication);
} }
if (requiresAuthentication((HttpServletRequest) request)) { if (requiresAuthentication((HttpServletRequest) request)) {