diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/validator/LDAPCredentialsValidator.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/validator/LDAPCredentialsValidator.java index 1db8799de38..9665f52fde8 100644 --- a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/validator/LDAPCredentialsValidator.java +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/validator/LDAPCredentialsValidator.java @@ -22,6 +22,7 @@ package org.apache.druid.security.basic.authentication.validator; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; +import com.google.common.annotations.VisibleForTesting; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.metadata.PasswordProvider; @@ -43,6 +44,7 @@ import javax.naming.directory.InitialDirContext; import javax.naming.directory.SearchControls; import javax.naming.directory.SearchResult; import javax.naming.ldap.LdapName; + import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; @@ -58,6 +60,9 @@ public class LDAPCredentialsValidator implements CredentialsValidator private final LruBlockCache cache; private final BasicAuthLDAPConfig ldapConfig; + // Custom overrides that can be passed via tests + @Nullable + private final Properties overrideProperties; @JsonCreator public LDAPCredentialsValidator( @@ -91,6 +96,19 @@ public class LDAPCredentialsValidator implements CredentialsValidator this.ldapConfig.getCredentialVerifyDuration(), this.ldapConfig.getCredentialMaxDuration() ); + this.overrideProperties = null; + } + + @VisibleForTesting + public LDAPCredentialsValidator( + final BasicAuthLDAPConfig ldapConfig, + final LruBlockCache cache, + final Properties overrideProperties + ) + { + this.ldapConfig = ldapConfig; + this.cache = cache; + this.overrideProperties = overrideProperties; } Properties bindProperties(BasicAuthLDAPConfig ldapConfig) @@ -119,6 +137,9 @@ public class LDAPCredentialsValidator implements CredentialsValidator properties.put(Context.SECURITY_PROTOCOL, "ssl"); properties.put("java.naming.ldap.factory.socket", BasicSecuritySSLSocketFactory.class.getName()); } + if (null != overrideProperties) { + properties.putAll(overrideProperties); + } return properties; } @@ -139,7 +160,11 @@ public class LDAPCredentialsValidator implements CredentialsValidator contextMap.put(BasicAuthUtils.SEARCH_RESULT_CONTEXT_KEY, principal.getSearchResult()); return new AuthenticationResult(username, authorizerName, authenticatorName, contextMap); } else { + ClassLoader currentClassLoader = Thread.currentThread().getContextClassLoader(); try { + // Set the context classloader same as the loader of this class so that BasicSecuritySSLSocketFactory + // class can be found + Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader()); InitialDirContext dirContext = new InitialDirContext(bindProperties(this.ldapConfig)); try { userResult = getLdapUserObject(this.ldapConfig, dirContext, username); @@ -162,6 +187,9 @@ public class LDAPCredentialsValidator implements CredentialsValidator LOG.error(e, "Exception during user lookup"); return null; } + finally { + Thread.currentThread().setContextClassLoader(currentClassLoader); + } if (!validatePassword(this.ldapConfig, userDn, password)) { LOG.debug("Password incorrect for LDAP user %s", username); @@ -213,8 +241,10 @@ public class LDAPCredentialsValidator implements CredentialsValidator boolean validatePassword(BasicAuthLDAPConfig ldapConfig, LdapName userDn, char[] password) { InitialDirContext context = null; + ClassLoader currentClassLoader = Thread.currentThread().getContextClassLoader(); try { + Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader()); context = new InitialDirContext(userProperties(ldapConfig, userDn, password)); return true; } @@ -235,10 +265,11 @@ public class LDAPCredentialsValidator implements CredentialsValidator LOG.warn("Exception closing LDAP context"); // ignored } + Thread.currentThread().setContextClassLoader(currentClassLoader); } } - private static class LruBlockCache extends LinkedHashMap + public static class LruBlockCache extends LinkedHashMap { private static final long serialVersionUID = 7509410739092012261L; diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/validator/LDAPCredentialsValidatorTest.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/validator/LDAPCredentialsValidatorTest.java index 56aee168c96..aabdea31b84 100644 --- a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/validator/LDAPCredentialsValidatorTest.java +++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/validator/LDAPCredentialsValidatorTest.java @@ -19,12 +19,43 @@ package org.apache.druid.security.authentication.validator; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.metadata.DefaultPasswordProvider; +import org.apache.druid.security.basic.BasicAuthLDAPConfig; +import org.apache.druid.security.basic.BasicAuthUtils; import org.apache.druid.security.basic.authentication.validator.LDAPCredentialsValidator; import org.junit.Assert; import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; + +import javax.naming.Context; +import javax.naming.NamingEnumeration; +import javax.naming.NamingException; +import javax.naming.directory.SearchControls; +import javax.naming.directory.SearchResult; +import javax.naming.ldap.LdapContext; +import javax.naming.spi.InitialContextFactory; + +import java.util.Collections; +import java.util.Hashtable; +import java.util.Iterator; +import java.util.Properties; public class LDAPCredentialsValidatorTest { + private static final BasicAuthLDAPConfig LDAP_CONFIG = new BasicAuthLDAPConfig( + "ldaps://my-ldap-url", + "bindUser", + new DefaultPasswordProvider("bindPassword"), + "", + "", + "", + BasicAuthUtils.DEFAULT_KEY_ITERATIONS, + BasicAuthUtils.DEFAULT_CREDENTIAL_VERIFY_DURATION_SECONDS, + BasicAuthUtils.DEFAULT_CREDENTIAL_MAX_DURATION_SECONDS, + BasicAuthUtils.DEFAULT_CREDENTIAL_CACHE_SIZE); + @Test public void testEncodeForLDAP_noSpecialChars() { @@ -44,4 +75,80 @@ public class LDAPCredentialsValidatorTest Assert.assertEquals(expectedWildcardTrue, encodedWildcardTrue); Assert.assertEquals(expectedWildcardFalse, encodedWildcardFalse); } + + /** + * This doesn't test password validation. + */ + @Test + public void testValidateCredentials() + { + Properties properties = new Properties(); + properties.put(Context.INITIAL_CONTEXT_FACTORY, MockContextFactory.class.getName()); + LDAPCredentialsValidator validator = new LDAPCredentialsValidator( + LDAP_CONFIG, + new LDAPCredentialsValidator.LruBlockCache( + 3600, + 3600, + 100 + ), + properties + ); + validator.validateCredentials("ldap", "ldap", "validUser", "password".toCharArray()); + } + + public static class MockContextFactory implements InitialContextFactory + { + @Override + public Context getInitialContext(Hashtable environment) throws NamingException + { + LdapContext context = Mockito.mock(LdapContext.class); + + String encodedUsername = LDAPCredentialsValidator.encodeForLDAP("validUser", true); + SearchResult result = Mockito.mock(SearchResult.class); + Mockito.when(result.getNameInNamespace()).thenReturn("uid=user,ou=Users,dc=example,dc=org"); + Iterator results = Collections.singletonList(result).iterator(); + + Mockito.when( + context.search( + ArgumentMatchers.eq(LDAP_CONFIG.getBaseDn()), + ArgumentMatchers.eq(StringUtils.format(LDAP_CONFIG.getUserSearch(), encodedUsername)), + ArgumentMatchers.any(SearchControls.class)) + ).thenReturn(new NamingEnumeration() + { + @Override + public SearchResult next() + { + return results.next(); + } + + @Override + public boolean hasMore() + { + return results.hasNext(); + } + + @Override + public void close() + { + // No-op + } + + @Override + public boolean hasMoreElements() + { + return results.hasNext(); + } + + @Override + public SearchResult nextElement() + { + return results.next(); + } + }); + + return context; + } + } + + }