diff --git a/shield/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java b/shield/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java index 9e36ff79ef9..cb4236a9d3e 100644 --- a/shield/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java +++ b/shield/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java @@ -68,19 +68,25 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm return doAuthenticate(token); } - CacheLoader callback = key -> { - if (logger.isDebugEnabled()) { - logger.debug("user not found in cache, proceeding with normal authentication"); - } - User user = doAuthenticate(token); - if (user == null) { - throw Exceptions.authenticationError("could not authenticate [{}]", token.principal()); - } - return new UserWithHash(user, token.credentials(), hasher); - }; - try { - UserWithHash userWithHash = cache.computeIfAbsent(token.principal(), callback); + UserWithHash userWithHash = cache.get(token.principal()); + if (userWithHash == null) { + if (logger.isDebugEnabled()) { + logger.debug("user not found in cache, proceeding with normal authentication"); + } + User user = doAuthenticate(token); + if (user == null) { + return null; + } + userWithHash = new UserWithHash(user, token.credentials(), hasher); + // it doesn't matter if we already computed it elsewhere + cache.put(token.principal(), userWithHash); + if (logger.isDebugEnabled()) { + logger.debug("authenticated user [{}], with roles [{}]", token.principal(), user.roles()); + } + return user; + } + final boolean hadHash = userWithHash.hasHash(); if (hadHash) { if (userWithHash.verify(token.credentials())) { @@ -91,9 +97,14 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm } } //this handles when a user's password has changed or the user was looked up for run as and not authenticated - expire(token.principal()); - userWithHash = cache.computeIfAbsent(token.principal(), callback); - + cache.invalidate(token.principal()); + User user = doAuthenticate(token); + if (user == null) { + return null; + } + userWithHash = new UserWithHash(user, token.credentials(), hasher); + // it doesn't matter if we already computed it elsewhere + cache.put(token.principal(), userWithHash); if (logger.isDebugEnabled()) { if (hadHash) { logger.debug("cached user's password changed. authenticated user [{}], with roles [{}]", token.principal(), userWithHash.user.roles()); @@ -103,7 +114,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm } return userWithHash.user; - } catch (ExecutionException ee) { + } catch (Exception ee) { if (logger.isTraceEnabled()) { logger.trace("realm [" + type() + "] could not authenticate [" + token.principal() + "]", ee); } else if (logger.isDebugEnabled()) { diff --git a/shield/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java b/shield/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java index 374f95f014c..63e55e264da 100644 --- a/shield/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java +++ b/shield/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java @@ -13,6 +13,9 @@ import org.elasticsearch.shield.authc.RealmConfig; import org.elasticsearch.test.ESTestCase; import org.junit.Before; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.arrayContaining; @@ -170,6 +173,67 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { assertThat(realm.lookupInvocationCounter.intValue(), is(0)); } + public void testCacheConcurrency() throws Exception { + final String username = "username"; + final SecuredString password = new SecuredString("changeme".toCharArray()); + final SecuredString randomPassword = new SecuredString(randomAsciiOfLength(password.length()).toCharArray()); + + final String passwordHash = new String(Hasher.BCRYPT.hash(password)); + RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings); + final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config) { + @Override + protected User doAuthenticate(UsernamePasswordToken token) { + // do something slow + if (BCrypt.checkpw(token.credentials(), passwordHash)) { + return new User.Simple(username, new String[]{"r1", "r2", "r3"}); + } + return null; + } + + @Override + protected User doLookupUser(String username) { + throw new UnsupportedOperationException("this method should not be called"); + } + + @Override + public boolean userLookupSupported() { + return false; + } + }; + + final CountDownLatch latch = new CountDownLatch(1); + final int numberOfThreads = randomIntBetween(8, 24); + List threads = new ArrayList<>(); + for (int i = 0; i < numberOfThreads; i++) { + final boolean invalidPassword = randomBoolean(); + threads.add(new Thread() { + @Override + public void run() { + try { + latch.await(); + for (int i = 0; i < 100; i++) { + User user = realm.authenticate(new UsernamePasswordToken(username, invalidPassword ? randomPassword : password)); + if (invalidPassword && user != null) { + throw new RuntimeException("invalid password led to an authenticated user: " + user.toString()); + } else if (invalidPassword == false && user == null) { + throw new RuntimeException("proper password led to a null user!"); + } + } + + } catch (InterruptedException e) {} + } + }); + } + + for (Thread thread : threads) { + thread.start(); + } + latch.countDown(); + for (Thread thread : threads) { + thread.join(); + } + } + static class FailingAuthenticationRealm extends CachingUsernamePasswordRealm { FailingAuthenticationRealm(Settings settings, Settings global) {