mirror of https://github.com/apache/nifi.git
NIFI-2186 Refactored CertificateUtils to separate logic for DN extraction from server/client sockets. Added logic to detect server/client mode encapsulated in exposed method.
Added unit tests for DN extraction. Corrected typo in Javadoc. Switched server/client socket logic for certificate extraction -- when the local socket is in client/server mode, the peer is necessarily the inverse. Fixed unit tests. Moved lazy-loading authentication access out of isDebugEnabled() control branch. This closes #622
This commit is contained in:
parent
039fd70ded
commit
4b9df7d1e2
|
@ -166,11 +166,51 @@ public final class CertificateUtils {
|
|||
return result;
|
||||
}
|
||||
|
||||
public static String extractClientDNFromSSLSocket(Socket socket) throws CertificateException {
|
||||
/**
|
||||
* Returns the DN extracted from the peer certificate (the server DN if run on the client; the client DN (if available) if run on the server).
|
||||
*
|
||||
* If the client auth setting is WANT or NONE and a client certificate is not present, this method will return {@code null}.
|
||||
* If the client auth is NEED, it will throw a {@link CertificateException}.
|
||||
*
|
||||
* @param socket the SSL Socket
|
||||
* @return the extracted DN
|
||||
* @throws CertificateException if there is a problem parsing the certificate
|
||||
*/
|
||||
public static String extractPeerDNFromSSLSocket(Socket socket) throws CertificateException {
|
||||
String dn = null;
|
||||
if (socket instanceof SSLSocket) {
|
||||
final SSLSocket sslSocket = (SSLSocket) socket;
|
||||
|
||||
boolean clientMode = sslSocket.getUseClientMode();
|
||||
logger.debug("SSL Socket in {} mode", clientMode ? "client" : "server");
|
||||
ClientAuth clientAuth = getClientAuthStatus(sslSocket);
|
||||
logger.debug("SSL Socket client auth status: {}", clientAuth);
|
||||
|
||||
if (clientMode) {
|
||||
logger.debug("This socket is in client mode, so attempting to extract certificate from remote 'server' socket");
|
||||
dn = extractPeerDNFromServerSSLSocket(sslSocket);
|
||||
} else {
|
||||
logger.debug("This socket is in server mode, so attempting to extract certificate from remote 'client' socket");
|
||||
dn = extractPeerDNFromClientSSLSocket(sslSocket);
|
||||
}
|
||||
}
|
||||
|
||||
return dn;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the DN extracted from the client certificate.
|
||||
*
|
||||
* If the client auth setting is WANT or NONE and a certificate is not present (and {@code respectClientAuth} is {@code true}), this method will return {@code null}.
|
||||
* If the client auth is NEED, it will throw a {@link CertificateException}.
|
||||
*
|
||||
* @param sslSocket the SSL Socket
|
||||
* @return the extracted DN
|
||||
* @throws CertificateException if there is a problem parsing the certificate
|
||||
*/
|
||||
private static String extractPeerDNFromClientSSLSocket(SSLSocket sslSocket) throws CertificateException {
|
||||
String dn = null;
|
||||
|
||||
/** The clientAuth value can be "need", "want", or "none"
|
||||
* A client must send client certificates for need, should for want, and will not for none.
|
||||
* This method should throw an exception if none are provided for need, return null if none are provided for want, and return null (without checking) for none.
|
||||
|
@ -185,6 +225,7 @@ public final class CertificateUtils {
|
|||
if (certChains != null && certChains.length > 0) {
|
||||
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
|
||||
dn = x509Certificate.getSubjectDN().getName().trim();
|
||||
logger.debug("Extracted DN={} from client certificate", dn);
|
||||
}
|
||||
} catch (SSLPeerUnverifiedException e) {
|
||||
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
|
||||
|
@ -198,8 +239,35 @@ public final class CertificateUtils {
|
|||
throw new CertificateException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
return dn;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the DN extracted from the server certificate.
|
||||
*
|
||||
* @param socket the SSL Socket
|
||||
* @return the extracted DN
|
||||
* @throws CertificateException if there is a problem parsing the certificate
|
||||
*/
|
||||
private static String extractPeerDNFromServerSSLSocket(Socket socket) throws CertificateException {
|
||||
String dn = null;
|
||||
if (socket instanceof SSLSocket) {
|
||||
final SSLSocket sslSocket = (SSLSocket) socket;
|
||||
try {
|
||||
final Certificate[] certChains = sslSocket.getSession().getPeerCertificates();
|
||||
if (certChains != null && certChains.length > 0) {
|
||||
X509Certificate x509Certificate = convertAbstractX509Certificate(certChains[0]);
|
||||
dn = x509Certificate.getSubjectDN().getName().trim();
|
||||
logger.debug("Extracted DN={} from server certificate", dn);
|
||||
}
|
||||
} catch (SSLPeerUnverifiedException e) {
|
||||
if (e.getMessage().equals(PEER_NOT_AUTHENTICATED_MSG)) {
|
||||
logger.error("The server did not present a certificate and thus the DN cannot" +
|
||||
" be extracted. Check that the other endpoint is providing a complete certificate chain");
|
||||
}
|
||||
throw new CertificateException(e);
|
||||
}
|
||||
}
|
||||
return dn;
|
||||
}
|
||||
|
||||
|
|
|
@ -297,16 +297,67 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
assert noneClientAuthStatus == CertificateUtils.ClientAuth.NONE
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void testShouldNotExtractClientCertificatesFromSSLSocketWithClientAuthNone() {
|
||||
void testShouldExtractClientCertificatesFromSSLServerSocketWithAnyClientAuth() {
|
||||
final String EXPECTED_DN = "CN=ncm.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
|
||||
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
|
||||
logger.info("Expected DN: ${EXPECTED_DN}")
|
||||
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
|
||||
|
||||
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
|
||||
|
||||
// This socket is in client mode, so the peer ("target") is a server
|
||||
|
||||
// Create mock sockets for each possible value of ClientAuth
|
||||
SSLSocket mockNoneSocket = [
|
||||
getUseClientMode : { -> true },
|
||||
getNeedClientAuth: { -> false },
|
||||
getWantClientAuth: { -> false },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
SSLSocket mockNeedSocket = [
|
||||
getUseClientMode : { -> true },
|
||||
getNeedClientAuth: { -> true },
|
||||
getWantClientAuth: { -> false },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
SSLSocket mockWantSocket = [
|
||||
getUseClientMode : { -> true },
|
||||
getNeedClientAuth: { -> false },
|
||||
getWantClientAuth: { -> true },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
// Act
|
||||
def resolvedServerDNs = [mockNeedSocket, mockWantSocket, mockNoneSocket].collect { SSLSocket mockSocket ->
|
||||
logger.info("Running test with socket ClientAuth setting: ${CertificateUtils.getClientAuthStatus(mockSocket)}")
|
||||
String serverDN = CertificateUtils.extractPeerDNFromSSLSocket(mockNoneSocket)
|
||||
logger.info("Extracted server DN: ${serverDN}")
|
||||
serverDN
|
||||
}
|
||||
|
||||
// Assert
|
||||
assert resolvedServerDNs.every { String serverDN ->
|
||||
CertificateUtils.compareDNs(serverDN, EXPECTED_DN)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldNotExtractClientCertificatesFromSSLClientSocketWithClientAuthNone() {
|
||||
// Arrange
|
||||
|
||||
// This socket is in server mode, so the peer ("target") is a client
|
||||
SSLSocket mockSocket = [
|
||||
getUseClientMode : { -> false },
|
||||
getNeedClientAuth: { -> false },
|
||||
getWantClientAuth: { -> false }
|
||||
] as SSLSocket
|
||||
|
||||
// Act
|
||||
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
|
||||
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
|
||||
logger.info("Extracted client DN: ${clientDN}")
|
||||
|
||||
// Assert
|
||||
|
@ -314,7 +365,7 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
}
|
||||
|
||||
@Test
|
||||
void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthWant() {
|
||||
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
|
||||
// Arrange
|
||||
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
|
||||
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
|
||||
|
@ -323,14 +374,16 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
|
||||
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
|
||||
|
||||
// This socket is in server mode, so the peer ("target") is a client
|
||||
SSLSocket mockSocket = [
|
||||
getUseClientMode : { -> false },
|
||||
getNeedClientAuth: { -> false },
|
||||
getWantClientAuth: { -> true },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
// Act
|
||||
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
|
||||
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
|
||||
logger.info("Extracted client DN: ${clientDN}")
|
||||
|
||||
// Assert
|
||||
|
@ -338,18 +391,22 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
}
|
||||
|
||||
@Test
|
||||
void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthWant() {
|
||||
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
|
||||
// Arrange
|
||||
SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession
|
||||
SSLSession mockSession = [getPeerCertificates: { ->
|
||||
throw new SSLPeerUnverifiedException("peer not authenticated")
|
||||
}] as SSLSession
|
||||
|
||||
// This socket is in server mode, so the peer ("target") is a client
|
||||
SSLSocket mockSocket = [
|
||||
getUseClientMode : { -> false },
|
||||
getNeedClientAuth: { -> false },
|
||||
getWantClientAuth: { -> true },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
// Act
|
||||
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
|
||||
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
|
||||
logger.info("Extracted client DN: ${clientDN}")
|
||||
|
||||
// Assert
|
||||
|
@ -358,7 +415,7 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
|
||||
|
||||
@Test
|
||||
void testShouldExtractClientCertificatesFromSSLSocketWithClientAuthNeed() {
|
||||
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
|
||||
// Arrange
|
||||
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
|
||||
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
|
||||
|
@ -367,14 +424,16 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
|
||||
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
|
||||
|
||||
// This socket is in server mode, so the peer ("target") is a client
|
||||
SSLSocket mockSocket = [
|
||||
getUseClientMode : { -> false },
|
||||
getNeedClientAuth: { -> true },
|
||||
getWantClientAuth: { -> false },
|
||||
getSession : { -> mockSession }
|
||||
] as SSLSocket
|
||||
|
||||
// Act
|
||||
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
|
||||
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
|
||||
logger.info("Extracted client DN: ${clientDN}")
|
||||
|
||||
// Assert
|
||||
|
@ -382,11 +441,15 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
}
|
||||
|
||||
@Test
|
||||
void testShouldHandleFailureToExtractClientCertificatesFromSSLSocketWithClientAuthNeed() {
|
||||
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
|
||||
// Arrange
|
||||
SSLSession mockSession = [getPeerCertificates: { -> throw new SSLPeerUnverifiedException("peer not authenticated") }] as SSLSession
|
||||
SSLSession mockSession = [getPeerCertificates: { ->
|
||||
throw new SSLPeerUnverifiedException("peer not authenticated")
|
||||
}] as SSLSession
|
||||
|
||||
// This socket is in server mode, so the peer ("target") is a client
|
||||
SSLSocket mockSocket = [
|
||||
getUseClientMode : { -> false },
|
||||
getNeedClientAuth: { -> true },
|
||||
getWantClientAuth: { -> false },
|
||||
getSession : { -> mockSession }
|
||||
|
@ -394,7 +457,7 @@ class CertificateUtilsTest extends GroovyTestCase {
|
|||
|
||||
// Act
|
||||
def msg = shouldFail(CertificateException) {
|
||||
String clientDN = CertificateUtils.extractClientDNFromSSLSocket(mockSocket)
|
||||
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
|
||||
logger.info("Extracted client DN: ${clientDN}")
|
||||
}
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ public abstract class AbstractNodeProtocolSender implements NodeProtocolSender {
|
|||
|
||||
private String getCoordinatorDN(Socket socket) {
|
||||
try {
|
||||
return CertificateUtils.extractClientDNFromSSLSocket(socket);
|
||||
return CertificateUtils.extractPeerDNFromSSLSocket(socket);
|
||||
} catch (CertificateException e) {
|
||||
throw new ProtocolException(e);
|
||||
}
|
||||
|
|
|
@ -187,7 +187,7 @@ public class SocketProtocolListener extends SocketListener implements ProtocolLi
|
|||
|
||||
private String getRequestorDN(Socket socket) {
|
||||
try {
|
||||
return CertificateUtils.extractClientDNFromSSLSocket(socket);
|
||||
return CertificateUtils.extractPeerDNFromSSLSocket(socket);
|
||||
} catch (CertificateException e) {
|
||||
throw new ProtocolException(e);
|
||||
}
|
||||
|
|
|
@ -48,8 +48,9 @@ public abstract class NiFiAuthenticationFilter extends GenericFilterBean {
|
|||
|
||||
@Override
|
||||
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException {
|
||||
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (log.isDebugEnabled()) {
|
||||
log.debug("Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication());
|
||||
log.debug("Checking secure context token: " + authentication);
|
||||
}
|
||||
|
||||
if (requiresAuthentication((HttpServletRequest) request)) {
|
||||
|
|
Loading…
Reference in New Issue