Security: propagate auth result to listeners (#36900)

After #30794, our caching realms limit each principal to a single auth
attempt at a time. This prevents hammering of external servers but can
cause a significant performance hit when requests need to go through a
realm that takes a long time to attempt to authenticate in order to get
to the realm that actually authenticates. In order to address this,
this change will propagate failed results to listeners if they use the
same set of credentials that the authentication attempt used. This does
prevent these stalled requests from retrying the authentication attempt
but the implementation does allow for new requests to retry the
attempt.
This commit is contained in:
Jay Modi 2019-01-08 08:52:12 -07:00 committed by GitHub
parent dd69553d4d
commit 1514bbcdde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 84 deletions

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.authc.support; package org.elasticsearch.xpack.security.authc.support;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
@ -29,7 +30,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm { public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm {
private final Cache<String, ListenableFuture<UserWithHash>> cache; private final Cache<String, ListenableFuture<CachedResult>> cache;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final boolean authenticationEnabled; private final boolean authenticationEnabled;
final Hasher cacheHasher; final Hasher cacheHasher;
@ -40,7 +41,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
this.threadPool = threadPool; this.threadPool = threadPool;
final TimeValue ttl = this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING); final TimeValue ttl = this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING);
if (ttl.getNanos() > 0) { if (ttl.getNanos() > 0) {
cache = CacheBuilder.<String, ListenableFuture<UserWithHash>>builder() cache = CacheBuilder.<String, ListenableFuture<CachedResult>>builder()
.setExpireAfterWrite(ttl) .setExpireAfterWrite(ttl)
.setMaximumWeight(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING)) .setMaximumWeight(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING))
.build(); .build();
@ -122,16 +123,18 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
assert cache != null; assert cache != null;
try { try {
final AtomicBoolean authenticationInCache = new AtomicBoolean(true); final AtomicBoolean authenticationInCache = new AtomicBoolean(true);
final ListenableFuture<UserWithHash> listenableCacheEntry = cache.computeIfAbsent(token.principal(), k -> { final ListenableFuture<CachedResult> listenableCacheEntry = cache.computeIfAbsent(token.principal(), k -> {
authenticationInCache.set(false); authenticationInCache.set(false);
return new ListenableFuture<>(); return new ListenableFuture<>();
}); });
if (authenticationInCache.get()) { if (authenticationInCache.get()) {
// there is a cached or an inflight authenticate request // there is a cached or an inflight authenticate request
listenableCacheEntry.addListener(ActionListener.wrap(authenticatedUserWithHash -> { listenableCacheEntry.addListener(ActionListener.wrap(cachedResult -> {
if (authenticatedUserWithHash != null && authenticatedUserWithHash.verify(token.credentials())) { final boolean credsMatch = cachedResult.verify(token.credentials());
if (cachedResult.authenticationResult.isAuthenticated()) {
if (credsMatch) {
// cached credential hash matches the credential hash for this forestalled request // cached credential hash matches the credential hash for this forestalled request
handleCachedAuthentication(authenticatedUserWithHash.user, ActionListener.wrap(cacheResult -> { handleCachedAuthentication(cachedResult.user, ActionListener.wrap(cacheResult -> {
if (cacheResult.isAuthenticated()) { if (cacheResult.isAuthenticated()) {
logger.debug("realm [{}] authenticated user [{}], with roles [{}]", logger.debug("realm [{}] authenticated user [{}], with roles [{}]",
name(), token.principal(), cacheResult.getUser().roles()); name(), token.principal(), cacheResult.getUser().roles());
@ -142,38 +145,39 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
listener.onResponse(cacheResult); listener.onResponse(cacheResult);
}, listener::onFailure)); }, listener::onFailure));
} else { } else {
// The inflight request has failed or its credential hash does not match the // its credential hash does not match the
// hash of the credential for this forestalled request. // hash of the credential for this forestalled request.
// clear cache and try to reach the authentication source again because password // clear cache and try to reach the authentication source again because password
// might have changed there and the local cached hash got stale // might have changed there and the local cached hash got stale
cache.invalidate(token.principal(), listenableCacheEntry); cache.invalidate(token.principal(), listenableCacheEntry);
authenticateWithCache(token, listener); authenticateWithCache(token, listener);
} }
}, e -> { } else if (credsMatch) {
// the inflight request failed, so try again, but first (always) make sure cache // not authenticated but instead of hammering reuse the result. a new
// is cleared of the failed authentication // request will trigger a retried auth
listener.onResponse(cachedResult.authenticationResult);
} else {
cache.invalidate(token.principal(), listenableCacheEntry); cache.invalidate(token.principal(), listenableCacheEntry);
authenticateWithCache(token, listener); authenticateWithCache(token, listener);
}), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext()); }
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
} else { } else {
// attempt authentication against the authentication source // attempt authentication against the authentication source
doAuthenticate(token, ActionListener.wrap(authResult -> { doAuthenticate(token, ActionListener.wrap(authResult -> {
if (authResult.isAuthenticated() && authResult.getUser().enabled()) { if (authResult.isAuthenticated() == false || authResult.getUser().enabled() == false) {
// compute the credential hash of this successful authentication request // a new request should trigger a new authentication
final UserWithHash userWithHash = new UserWithHash(authResult.getUser(), token.credentials(), cacheHasher); cache.invalidate(token.principal(), listenableCacheEntry);
// notify any forestalled request listeners; they will not reach to the
// authentication request and instead will use this hash for comparison
listenableCacheEntry.onResponse(userWithHash);
} else {
// notify any forestalled request listeners; they will retry the request
listenableCacheEntry.onResponse(null);
} }
// notify the listener of the inflight authentication request; this request is not retried // notify any forestalled request listeners; they will not reach to the
// authentication request and instead will use this result if they contain
// the same credentials
listenableCacheEntry.onResponse(new CachedResult(authResult, cacheHasher, authResult.getUser(), token.credentials()));
listener.onResponse(authResult); listener.onResponse(authResult);
}, e -> { }, e -> {
// notify any staved off listeners; they will retry the request cache.invalidate(token.principal(), listenableCacheEntry);
// notify any staved off listeners; they will propagate this error
listenableCacheEntry.onFailure(e); listenableCacheEntry.onFailure(e);
// notify the listener of the inflight authentication request; this request is not retried // notify the listener of the inflight authentication request
listener.onFailure(e); listener.onFailure(e);
})); }));
} }
@ -225,25 +229,21 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
assert cache != null; assert cache != null;
try { try {
final AtomicBoolean lookupInCache = new AtomicBoolean(true); final AtomicBoolean lookupInCache = new AtomicBoolean(true);
final ListenableFuture<UserWithHash> listenableCacheEntry = cache.computeIfAbsent(username, key -> { final ListenableFuture<CachedResult> listenableCacheEntry = cache.computeIfAbsent(username, key -> {
lookupInCache.set(false); lookupInCache.set(false);
return new ListenableFuture<>(); return new ListenableFuture<>();
}); });
if (false == lookupInCache.get()) { if (false == lookupInCache.get()) {
// attempt lookup against the user directory // attempt lookup against the user directory
doLookupUser(username, ActionListener.wrap(user -> { doLookupUser(username, ActionListener.wrap(user -> {
if (user != null) { final CachedResult result = new CachedResult(AuthenticationResult.notHandled(), cacheHasher, user, null);
// user found if (user == null) {
final UserWithHash userWithHash = new UserWithHash(user, null, null);
// notify forestalled request listeners
listenableCacheEntry.onResponse(userWithHash);
} else {
// user not found, invalidate cache so that subsequent requests are forwarded to // user not found, invalidate cache so that subsequent requests are forwarded to
// the user directory // the user directory
cache.invalidate(username, listenableCacheEntry); cache.invalidate(username, listenableCacheEntry);
// notify forestalled request listeners
listenableCacheEntry.onResponse(null);
} }
// notify forestalled request listeners
listenableCacheEntry.onResponse(result);
}, e -> { }, e -> {
// the next request should be forwarded, not halted by a failed lookup attempt // the next request should be forwarded, not halted by a failed lookup attempt
cache.invalidate(username, listenableCacheEntry); cache.invalidate(username, listenableCacheEntry);
@ -251,9 +251,9 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
listenableCacheEntry.onFailure(e); listenableCacheEntry.onFailure(e);
})); }));
} }
listenableCacheEntry.addListener(ActionListener.wrap(userWithHash -> { listenableCacheEntry.addListener(ActionListener.wrap(cachedResult -> {
if (userWithHash != null) { if (cachedResult.user != null) {
listener.onResponse(userWithHash.user); listener.onResponse(cachedResult.user);
} else { } else {
listener.onResponse(null); listener.onResponse(null);
} }
@ -265,16 +265,21 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
protected abstract void doLookupUser(String username, ActionListener<User> listener); protected abstract void doLookupUser(String username, ActionListener<User> listener);
private static class UserWithHash { private static class CachedResult {
final User user; private final AuthenticationResult authenticationResult;
final char[] hash; private final User user;
private final char[] hash;
UserWithHash(User user, SecureString password, Hasher hasher) { private CachedResult(AuthenticationResult result, Hasher hasher, @Nullable User user, @Nullable SecureString password) {
this.user = Objects.requireNonNull(user); this.authenticationResult = Objects.requireNonNull(result);
if (authenticationResult.isAuthenticated() && user == null) {
throw new IllegalArgumentException("authentication cannot be successful with a null user");
}
this.user = user;
this.hash = password == null ? null : hasher.hash(password); this.hash = password == null ? null : hasher.hash(password);
} }
boolean verify(SecureString password) { private boolean verify(SecureString password) {
return hash != null && Hasher.verifyHash(password, hash); return hash != null && Hasher.verifyHash(password, hash);
} }
} }

View File

@ -58,13 +58,13 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
@After @After
public void stop() throws InterruptedException { public void stop() {
if (threadPool != null) { if (threadPool != null) {
terminate(threadPool); terminate(threadPool);
} }
} }
public void testCacheSettings() throws Exception { public void testCacheSettings() {
String cachingHashAlgo = Hasher.values()[randomIntBetween(0, Hasher.values().length - 1)].name().toLowerCase(Locale.ROOT); String cachingHashAlgo = Hasher.values()[randomIntBetween(0, Hasher.values().length - 1)].name().toLowerCase(Locale.ROOT);
int maxUsers = randomIntBetween(10, 100); int maxUsers = randomIntBetween(10, 100);
TimeValue ttl = TimeValue.timeValueMinutes(randomIntBetween(10, 20)); TimeValue ttl = TimeValue.timeValueMinutes(randomIntBetween(10, 20));
@ -352,7 +352,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
} }
public void testAuthenticateContract() throws Exception { public void testAuthenticateContract() {
Realm realm = new FailingAuthenticationRealm(globalSettings, threadPool); Realm realm = new FailingAuthenticationRealm(globalSettings, threadPool);
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future); realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future);
@ -366,7 +366,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
assertThat(e.getMessage(), containsString("whatever exception")); assertThat(e.getMessage(), containsString("whatever exception"));
} }
public void testLookupContract() throws Exception { public void testLookupContract() {
Realm realm = new FailingAuthenticationRealm(globalSettings, threadPool); Realm realm = new FailingAuthenticationRealm(globalSettings, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("user", future); realm.lookupUser("user", future);
@ -380,7 +380,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
assertThat(e.getMessage(), containsString("lookup exception")); assertThat(e.getMessage(), containsString("lookup exception"));
} }
public void testReturnDifferentObjectFromCache() throws Exception { public void testReturnDifferentObjectFromCache() {
final AtomicReference<User> userArg = new AtomicReference<>(); final AtomicReference<User> userArg = new AtomicReference<>();
final AtomicReference<AuthenticationResult> result = new AtomicReference<>(); final AtomicReference<AuthenticationResult> result = new AtomicReference<>();
Realm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool) { Realm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool) {
@ -473,6 +473,71 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
assertEquals(1, authCounter.get()); assertEquals(1, authCounter.get());
} }
public void testUnauthenticatedResultPropagatesWithSameCreds() throws Exception {
final String username = "username";
final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING;
final AtomicInteger authCounter = new AtomicInteger(0);
final Hasher pwdHasher = Hasher.resolve(randomFrom("pbkdf2", "pbkdf2_1000", "bcrypt", "bcrypt9"));
final String passwordHash = new String(pwdHasher.hash(password));
RealmConfig config = new RealmConfig(new RealmConfig.RealmIdentifier("caching", "test_realm"), globalSettings,
TestEnvironment.newEnvironment(globalSettings), new ThreadContext(Settings.EMPTY));
final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads);
final SecureString credsToUse = new SecureString(randomAlphaOfLength(12).toCharArray());
final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm(config, threadPool) {
@Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
authCounter.incrementAndGet();
// do something slow
if (pwdHasher.verify(token.credentials(), passwordHash.toCharArray())) {
listener.onFailure(new IllegalStateException("password auth should never succeed"));
} else {
listener.onResponse(AuthenticationResult.unsuccessful("password verification failed", null));
}
}
@Override
protected void doLookupUser(String username, ActionListener<User> listener) {
listener.onFailure(new UnsupportedOperationException("this method should not be called"));
}
};
for (int i = 0; i < numberOfThreads; i++) {
threads.add(new Thread(() -> {
try {
latch.countDown();
latch.await();
final UsernamePasswordToken token = new UsernamePasswordToken(username, credsToUse);
realm.authenticate(token, ActionListener.wrap((result) -> {
if (result.isAuthenticated()) {
throw new IllegalStateException("invalid password led to an authenticated result: " + result);
}
assertThat(result.getMessage(), containsString("password verification failed"));
}, (e) -> {
logger.error("caught exception", e);
fail("unexpected exception - " + e);
}));
} catch (InterruptedException e) {
logger.error("thread was interrupted", e);
Thread.currentThread().interrupt();
}
}));
}
for (Thread thread : threads) {
thread.start();
}
latch.countDown();
for (Thread thread : threads) {
thread.join();
}
assertEquals(1, authCounter.get());
}
public void testCacheConcurrency() throws Exception { public void testCacheConcurrency() throws Exception {
final String username = "username"; final String username = "username";
final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING; final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING;
@ -704,27 +769,4 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
listener.onResponse(new User(username, new String[]{"lookupRole1", "lookupRole2"})); listener.onResponse(new User(username, new String[]{"lookupRole1", "lookupRole2"}));
} }
} }
static class LookupNotSupportedRealm extends CachingUsernamePasswordRealm {
public final AtomicInteger authInvocationCounter = new AtomicInteger(0);
public final AtomicInteger lookupInvocationCounter = new AtomicInteger(0);
LookupNotSupportedRealm(Settings globalSettings, ThreadPool threadPool) {
super(new RealmConfig(new RealmConfig.RealmIdentifier("caching", "lookup-notsupported-test"), globalSettings,
TestEnvironment.newEnvironment(globalSettings), threadPool.getThreadContext()), threadPool);
}
@Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
authInvocationCounter.incrementAndGet();
listener.onResponse(AuthenticationResult.success(new User(token.principal(), new String[]{"testRole1", "testRole2"})));
}
@Override
protected void doLookupUser(String username, ActionListener<User> listener) {
lookupInvocationCounter.incrementAndGet();
listener.onFailure(new UnsupportedOperationException("don't call lookup if lookup isn't supported!!!"));
}
}
} }