Limit user to single concurrent auth per realm (#30794)

This commit reworks the way our realms perform caching in order to
limit each principal to a single ongoing authentication per realm. In
other words, this means that multiple requests made by the same user
will not trigger more that one authentication attempt at a time if no
entry has been stored in the cache. If an entry is present in our
cache, there is no restriction on the number of concurrent
authentications performed for this user.

This change enables us to limit the load we place on an external system
like an LDAP server and also preserve resources such as CPU on
expensive operations such as BCrypt authentication.

Closes #30355
This commit is contained in:
Jay Modi 2018-05-24 10:43:10 -06:00 committed by GitHub
parent 9cb6b90a99
commit b3a4acdf20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 712 additions and 227 deletions

View File

@ -68,6 +68,7 @@ import java.util.function.ToLongBiFunction;
* @param <V> The type of the values * @param <V> The type of the values
*/ */
public class Cache<K, V> { public class Cache<K, V> {
// positive if entries have an expiration // positive if entries have an expiration
private long expireAfterAccessNanos = -1; private long expireAfterAccessNanos = -1;
@ -282,6 +283,39 @@ public class Cache<K, V> {
} }
} }
/**
* remove an entry from the segment iff the future is done and the value is equal to the
* expected value
*
* @param key the key of the entry to remove from the cache
* @param value the value expected to be associated with the key
* @param onRemoval a callback for the removed entry
*/
void remove(K key, V value, Consumer<CompletableFuture<Entry<K, V>>> onRemoval) {
CompletableFuture<Entry<K, V>> future;
boolean removed = false;
try (ReleasableLock ignored = writeLock.acquire()) {
future = map.get(key);
try {
if (future != null) {
if (future.isDone()) {
Entry<K, V> entry = future.get();
if (Objects.equals(value, entry.value)) {
removed = map.remove(key, future);
}
}
}
} catch (ExecutionException | InterruptedException e) {
throw new IllegalStateException(e);
}
}
if (future != null && removed) {
segmentStats.eviction();
onRemoval.accept(future);
}
}
private static class SegmentStats { private static class SegmentStats {
private final LongAdder hits = new LongAdder(); private final LongAdder hits = new LongAdder();
private final LongAdder misses = new LongAdder(); private final LongAdder misses = new LongAdder();
@ -314,7 +348,7 @@ public class Cache<K, V> {
Entry<K, V> tail; Entry<K, V> tail;
// lock protecting mutations to the LRU list // lock protecting mutations to the LRU list
private ReleasableLock lruLock = new ReleasableLock(new ReentrantLock()); private final ReleasableLock lruLock = new ReleasableLock(new ReentrantLock());
/** /**
* Returns the value to which the specified key is mapped, or null if this map contains no mapping for the key. * Returns the value to which the specified key is mapped, or null if this map contains no mapping for the key.
@ -455,6 +489,19 @@ public class Cache<K, V> {
} }
} }
private final Consumer<CompletableFuture<Entry<K, V>>> invalidationConsumer = f -> {
try {
Entry<K, V> entry = f.get();
try (ReleasableLock ignored = lruLock.acquire()) {
delete(entry, RemovalNotification.RemovalReason.INVALIDATED);
}
} catch (ExecutionException e) {
// ok
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
};
/** /**
* Invalidate the association for the specified key. A removal notification will be issued for invalidated * Invalidate the association for the specified key. A removal notification will be issued for invalidated
* entries with {@link org.elasticsearch.common.cache.RemovalNotification.RemovalReason} INVALIDATED. * entries with {@link org.elasticsearch.common.cache.RemovalNotification.RemovalReason} INVALIDATED.
@ -463,18 +510,20 @@ public class Cache<K, V> {
*/ */
public void invalidate(K key) { public void invalidate(K key) {
CacheSegment<K, V> segment = getCacheSegment(key); CacheSegment<K, V> segment = getCacheSegment(key);
segment.remove(key, f -> { segment.remove(key, invalidationConsumer);
try { }
Entry<K, V> entry = f.get();
try (ReleasableLock ignored = lruLock.acquire()) { /**
delete(entry, RemovalNotification.RemovalReason.INVALIDATED); * Invalidate the entry for the specified key and value. If the value provided is not equal to the value in
} * the cache, no removal will occur. A removal notification will be issued for invalidated
} catch (ExecutionException e) { * entries with {@link org.elasticsearch.common.cache.RemovalNotification.RemovalReason} INVALIDATED.
// ok *
} catch (InterruptedException e) { * @param key the key whose mapping is to be invalidated from the cache
throw new IllegalStateException(e); * @param value the expected value that should be associated with the key
} */
}); public void invalidate(K key, V value) {
CacheSegment<K, V> segment = getCacheSegment(key);
segment.remove(key, value, invalidationConsumer);
} }
/** /**
@ -625,7 +674,7 @@ public class Cache<K, V> {
Entry<K, V> entry = current; Entry<K, V> entry = current;
if (entry != null) { if (entry != null) {
CacheSegment<K, V> segment = getCacheSegment(entry.key); CacheSegment<K, V> segment = getCacheSegment(entry.key);
segment.remove(entry.key, f -> {}); segment.remove(entry.key, entry.value, f -> {});
try (ReleasableLock ignored = lruLock.acquire()) { try (ReleasableLock ignored = lruLock.acquire()) {
current = null; current = null;
delete(entry, RemovalNotification.RemovalReason.INVALIDATED); delete(entry, RemovalNotification.RemovalReason.INVALIDATED);
@ -710,7 +759,7 @@ public class Cache<K, V> {
CacheSegment<K, V> segment = getCacheSegment(entry.key); CacheSegment<K, V> segment = getCacheSegment(entry.key);
if (segment != null) { if (segment != null) {
segment.remove(entry.key, f -> {}); segment.remove(entry.key, entry.value, f -> {});
} }
delete(entry, RemovalNotification.RemovalReason.EVICTED); delete(entry, RemovalNotification.RemovalReason.EVICTED);
} }

View File

@ -0,0 +1,115 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.util.concurrent;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.collect.Tuple;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
/**
* A future implementation that allows for the result to be passed to listeners waiting for
* notification. This is useful for cases where a computation is requested many times
* concurrently, but really only needs to be performed a single time. Once the computation
* has been performed the registered listeners will be notified by submitting a runnable
* for execution in the provided {@link ExecutorService}. If the computation has already
* been performed, a request to add a listener will simply result in execution of the listener
* on the calling thread.
*/
public final class ListenableFuture<V> extends BaseFuture<V> implements ActionListener<V> {
private volatile boolean done = false;
private final List<Tuple<ActionListener<V>, ExecutorService>> listeners = new ArrayList<>();
/**
* Adds a listener to this future. If the future has not yet completed, the listener will be
* notified of a response or exception in a runnable submitted to the ExecutorService provided.
* If the future has completed, the listener will be notified immediately without forking to
* a different thread.
*/
public void addListener(ActionListener<V> listener, ExecutorService executor) {
if (done) {
// run the callback directly, we don't hold the lock and don't need to fork!
notifyListener(listener, EsExecutors.newDirectExecutorService());
} else {
final boolean run;
// check done under lock since it could have been modified and protect modifications
// to the list under lock
synchronized (this) {
if (done) {
run = true;
} else {
listeners.add(new Tuple<>(listener, executor));
run = false;
}
}
if (run) {
// run the callback directly, we don't hold the lock and don't need to fork!
notifyListener(listener, EsExecutors.newDirectExecutorService());
}
}
}
@Override
protected synchronized void done() {
done = true;
listeners.forEach(t -> notifyListener(t.v1(), t.v2()));
// release references to any listeners as we no longer need them and will live
// much longer than the listeners in most cases
listeners.clear();
}
private void notifyListener(ActionListener<V> listener, ExecutorService executorService) {
try {
executorService.submit(() -> {
try {
// call get in a non-blocking fashion as we could be on a network thread
// or another thread like the scheduler, which we should never block!
V value = FutureUtils.get(this, 0L, TimeUnit.NANOSECONDS);
listener.onResponse(value);
} catch (Exception e) {
listener.onFailure(e);
}
});
} catch (Exception e) {
listener.onFailure(e);
}
}
@Override
public void onResponse(V v) {
final boolean set = set(v);
if (set == false) {
throw new IllegalStateException("did not set value, value or exception already set?");
}
}
@Override
public void onFailure(Exception e) {
final boolean set = setException(e);
if (set == false) {
throw new IllegalStateException("did not set exception, value already set or exception already set?");
}
}
}

View File

@ -457,6 +457,62 @@ public class CacheTests extends ESTestCase {
assertEquals(notifications, invalidated); assertEquals(notifications, invalidated);
} }
// randomly invalidate some cached entries, then check that a lookup for each of those and only those keys is null
public void testInvalidateWithValue() {
Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
for (int i = 0; i < numberOfEntries; i++) {
cache.put(i, Integer.toString(i));
}
Set<Integer> keys = new HashSet<>();
for (Integer key : cache.keys()) {
if (rarely()) {
if (randomBoolean()) {
cache.invalidate(key, key.toString());
keys.add(key);
} else {
// invalidate with incorrect value
cache.invalidate(key, Integer.toString(key * randomIntBetween(2, 10)));
}
}
}
for (int i = 0; i < numberOfEntries; i++) {
if (keys.contains(i)) {
assertNull(cache.get(i));
} else {
assertNotNull(cache.get(i));
}
}
}
// randomly invalidate some cached entries, then check that we receive invalidate notifications for those and only
// those entries
public void testNotificationOnInvalidateWithValue() {
Set<Integer> notifications = new HashSet<>();
Cache<Integer, String> cache =
CacheBuilder.<Integer, String>builder()
.removalListener(notification -> {
assertEquals(RemovalNotification.RemovalReason.INVALIDATED, notification.getRemovalReason());
notifications.add(notification.getKey());
})
.build();
for (int i = 0; i < numberOfEntries; i++) {
cache.put(i, Integer.toString(i));
}
Set<Integer> invalidated = new HashSet<>();
for (int i = 0; i < numberOfEntries; i++) {
if (rarely()) {
if (randomBoolean()) {
cache.invalidate(i, Integer.toString(i));
invalidated.add(i);
} else {
// invalidate with incorrect value
cache.invalidate(i, Integer.toString(i * randomIntBetween(2, 10)));
}
}
}
assertEquals(notifications, invalidated);
}
// invalidate all cached entries, then check that the cache is empty // invalidate all cached entries, then check that the cache is empty
public void testInvalidateAll() { public void testInvalidateAll() {
Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build(); Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();

View File

@ -0,0 +1,118 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.util.concurrent;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
public class ListenableFutureTests extends ESTestCase {
private ExecutorService executorService;
@After
public void stopExecutorService() throws InterruptedException {
if (executorService != null) {
terminate(executorService);
}
}
public void testListenableFutureNotifiesListeners() {
ListenableFuture<String> future = new ListenableFuture<>();
AtomicInteger notifications = new AtomicInteger(0);
final int numberOfListeners = scaledRandomIntBetween(1, 12);
for (int i = 0; i < numberOfListeners; i++) {
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
}
future.onResponse("");
assertEquals(numberOfListeners, notifications.get());
assertTrue(future.isDone());
}
public void testListenableFutureNotifiesListenersOnException() {
ListenableFuture<String> future = new ListenableFuture<>();
AtomicInteger notifications = new AtomicInteger(0);
final int numberOfListeners = scaledRandomIntBetween(1, 12);
final Exception exception = new RuntimeException();
for (int i = 0; i < numberOfListeners; i++) {
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
assertEquals(exception, e);
notifications.incrementAndGet();
}), EsExecutors.newDirectExecutorService());
}
future.onFailure(exception);
assertEquals(numberOfListeners, notifications.get());
assertTrue(future.isDone());
}
public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarrierException, InterruptedException {
final int numberOfThreads = scaledRandomIntBetween(2, 32);
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
final ListenableFuture<String> future = new ListenableFuture<>();
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
final AtomicInteger numResponses = new AtomicInteger(0);
final AtomicInteger numExceptions = new AtomicInteger(0);
for (int i = 0; i < numberOfThreads; i++) {
final int threadNum = i;
Thread thread = new Thread(() -> {
try {
barrier.await();
if (threadNum == completingThread) {
future.onResponse("");
} else {
future.addListener(ActionListener.wrap(s -> {
assertEquals("", s);
numResponses.incrementAndGet();
listenersLatch.countDown();
}, e -> {
logger.error("caught unexpected exception", e);
numExceptions.incrementAndGet();
listenersLatch.countDown();
}), executorService);
}
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
throw new AssertionError(e);
}
});
thread.start();
}
barrier.await();
barrier.await();
listenersLatch.await();
assertEquals(numberOfThreads - 1, numResponses.get());
assertEquals(0, numExceptions.get());
}
}

View File

@ -410,7 +410,7 @@ public class Security extends Plugin implements ActionPlugin, IngestPlugin, Netw
final NativeRoleMappingStore nativeRoleMappingStore = new NativeRoleMappingStore(settings, client, securityIndex.get()); final NativeRoleMappingStore nativeRoleMappingStore = new NativeRoleMappingStore(settings, client, securityIndex.get());
final AnonymousUser anonymousUser = new AnonymousUser(settings); final AnonymousUser anonymousUser = new AnonymousUser(settings);
final ReservedRealm reservedRealm = new ReservedRealm(env, settings, nativeUsersStore, final ReservedRealm reservedRealm = new ReservedRealm(env, settings, nativeUsersStore,
anonymousUser, securityIndex.get(), threadPool.getThreadContext()); anonymousUser, securityIndex.get(), threadPool);
Map<String, Realm.Factory> realmFactories = new HashMap<>(InternalRealms.getFactories(threadPool, resourceWatcherService, Map<String, Realm.Factory> realmFactories = new HashMap<>(InternalRealms.getFactories(threadPool, resourceWatcherService,
getSslService(), nativeUsersStore, nativeRoleMappingStore, securityIndex.get())); getSslService(), nativeUsersStore, nativeRoleMappingStore, securityIndex.get()));
for (SecurityExtension extension : securityExtensions) { for (SecurityExtension extension : securityExtensions) {

View File

@ -93,9 +93,9 @@ public final class InternalRealms {
SecurityIndexManager securityIndex) { SecurityIndexManager securityIndex) {
Map<String, Realm.Factory> map = new HashMap<>(); Map<String, Realm.Factory> map = new HashMap<>();
map.put(FileRealmSettings.TYPE, config -> new FileRealm(config, resourceWatcherService)); map.put(FileRealmSettings.TYPE, config -> new FileRealm(config, resourceWatcherService, threadPool));
map.put(NativeRealmSettings.TYPE, config -> { map.put(NativeRealmSettings.TYPE, config -> {
final NativeRealm nativeRealm = new NativeRealm(config, nativeUsersStore); final NativeRealm nativeRealm = new NativeRealm(config, nativeUsersStore, threadPool);
securityIndex.addIndexStateListener(nativeRealm::onSecurityIndexStateChange); securityIndex.addIndexStateListener(nativeRealm::onSecurityIndexStateChange);
return nativeRealm; return nativeRealm;
}); });

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.authc.esnative; package org.elasticsearch.xpack.security.authc.esnative;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
import org.elasticsearch.xpack.core.security.authc.esnative.NativeRealmSettings; import org.elasticsearch.xpack.core.security.authc.esnative.NativeRealmSettings;
@ -24,8 +25,8 @@ public class NativeRealm extends CachingUsernamePasswordRealm {
private final NativeUsersStore userStore; private final NativeUsersStore userStore;
public NativeRealm(RealmConfig config, NativeUsersStore usersStore) { public NativeRealm(RealmConfig config, NativeUsersStore usersStore, ThreadPool threadPool) {
super(NativeRealmSettings.TYPE, config); super(NativeRealmSettings.TYPE, config, threadPool);
this.userStore = usersStore; this.userStore = usersStore;
} }

View File

@ -14,8 +14,8 @@ import org.elasticsearch.common.settings.SecureSetting;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment; import org.elasticsearch.env.Environment;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.SecurityField; import org.elasticsearch.xpack.core.security.SecurityField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
@ -66,8 +66,8 @@ public class ReservedRealm extends CachingUsernamePasswordRealm {
private final SecurityIndexManager securityIndex; private final SecurityIndexManager securityIndex;
public ReservedRealm(Environment env, Settings settings, NativeUsersStore nativeUsersStore, AnonymousUser anonymousUser, public ReservedRealm(Environment env, Settings settings, NativeUsersStore nativeUsersStore, AnonymousUser anonymousUser,
SecurityIndexManager securityIndex, ThreadContext threadContext) { SecurityIndexManager securityIndex, ThreadPool threadPool) {
super(TYPE, new RealmConfig(TYPE, Settings.EMPTY, settings, env, threadContext)); super(TYPE, new RealmConfig(TYPE, Settings.EMPTY, settings, env, threadPool.getThreadContext()), threadPool);
this.nativeUsersStore = nativeUsersStore; this.nativeUsersStore = nativeUsersStore;
this.realmEnabled = XPackSettings.RESERVED_REALM_ENABLED_SETTING.get(settings); this.realmEnabled = XPackSettings.RESERVED_REALM_ENABLED_SETTING.get(settings);
this.anonymousUser = anonymousUser; this.anonymousUser = anonymousUser;

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.authc.file; package org.elasticsearch.xpack.security.authc.file;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@ -21,13 +22,13 @@ public class FileRealm extends CachingUsernamePasswordRealm {
private final FileUserPasswdStore userPasswdStore; private final FileUserPasswdStore userPasswdStore;
private final FileUserRolesStore userRolesStore; private final FileUserRolesStore userRolesStore;
public FileRealm(RealmConfig config, ResourceWatcherService watcherService) { public FileRealm(RealmConfig config, ResourceWatcherService watcherService, ThreadPool threadPool) {
this(config, new FileUserPasswdStore(config, watcherService), new FileUserRolesStore(config, watcherService)); this(config, new FileUserPasswdStore(config, watcherService), new FileUserRolesStore(config, watcherService), threadPool);
} }
// pkg private for testing // pkg private for testing
FileRealm(RealmConfig config, FileUserPasswdStore userPasswdStore, FileUserRolesStore userRolesStore) { FileRealm(RealmConfig config, FileUserPasswdStore userPasswdStore, FileUserRolesStore userRolesStore, ThreadPool threadPool) {
super(FileRealmSettings.TYPE, config); super(FileRealmSettings.TYPE, config, threadPool);
this.userPasswdStore = userPasswdStore; this.userPasswdStore = userPasswdStore;
userPasswdStore.addListener(this::expireAll); userPasswdStore.addListener(this::expireAll);
this.userRolesStore = userRolesStore; this.userRolesStore = userRolesStore;

View File

@ -67,7 +67,7 @@ public final class LdapRealm extends CachingUsernamePasswordRealm {
// pkg private for testing // pkg private for testing
LdapRealm(String type, RealmConfig config, SessionFactory sessionFactory, LdapRealm(String type, RealmConfig config, SessionFactory sessionFactory,
UserRoleMapper roleMapper, ThreadPool threadPool) { UserRoleMapper roleMapper, ThreadPool threadPool) {
super(type, config); super(type, config, threadPool);
this.sessionFactory = sessionFactory; this.sessionFactory = sessionFactory;
this.roleMapper = roleMapper; this.roleMapper = roleMapper;
this.threadPool = threadPool; this.threadPool = threadPool;

View File

@ -5,11 +5,15 @@
*/ */
package org.elasticsearch.xpack.security.authc.support; package org.elasticsearch.xpack.security.authc.support;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@ -21,18 +25,21 @@ import org.elasticsearch.xpack.core.security.user.User;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm { public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm {
private final Cache<String, UserWithHash> cache; private final Cache<String, ListenableFuture<Tuple<AuthenticationResult, UserWithHash>>> cache;
private final ThreadPool threadPool;
final Hasher hasher; final Hasher hasher;
protected CachingUsernamePasswordRealm(String type, RealmConfig config) { protected CachingUsernamePasswordRealm(String type, RealmConfig config, ThreadPool threadPool) {
super(type, config); super(type, config);
hasher = Hasher.resolve(CachingUsernamePasswordRealmSettings.CACHE_HASH_ALGO_SETTING.get(config.settings()), Hasher.SSHA256); hasher = Hasher.resolve(CachingUsernamePasswordRealmSettings.CACHE_HASH_ALGO_SETTING.get(config.settings()), Hasher.SSHA256);
this.threadPool = threadPool;
TimeValue ttl = CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING.get(config.settings()); TimeValue ttl = CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING.get(config.settings());
if (ttl.getNanos() > 0) { if (ttl.getNanos() > 0) {
cache = CacheBuilder.<String, UserWithHash>builder() cache = CacheBuilder.<String, ListenableFuture<Tuple<AuthenticationResult, UserWithHash>>>builder()
.setExpireAfterWrite(ttl) .setExpireAfterWrite(ttl)
.setMaximumWeight(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING.get(config.settings())) .setMaximumWeight(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING.get(config.settings()))
.build(); .build();
@ -78,74 +85,95 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
} }
private void authenticateWithCache(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { private void authenticateWithCache(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
UserWithHash userWithHash = cache.get(token.principal()); try {
if (userWithHash == null) { final SetOnce<User> authenticatedUser = new SetOnce<>();
if (logger.isDebugEnabled()) { final AtomicBoolean createdAndStartedFuture = new AtomicBoolean(false);
logger.debug("user [{}] not found in cache for realm [{}], proceeding with normal authentication", final ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future = cache.computeIfAbsent(token.principal(), k -> {
token.principal(), name()); final ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> created = new ListenableFuture<>();
} if (createdAndStartedFuture.compareAndSet(false, true) == false) {
doAuthenticateAndCache(token, ActionListener.wrap((result) -> { throw new IllegalStateException("something else already started this. how?");
if (result.isAuthenticated()) {
final User user = result.getUser();
logger.debug("realm [{}] authenticated user [{}], with roles [{}]", name(), token.principal(), user.roles());
} }
listener.onResponse(result); return created;
}, listener::onFailure)); });
} else if (userWithHash.hasHash()) {
if (userWithHash.verify(token.credentials())) { if (createdAndStartedFuture.get()) {
if (userWithHash.user.enabled()) { doAuthenticate(token, ActionListener.wrap(result -> {
User user = userWithHash.user;
logger.debug("realm [{}] authenticated user [{}], with roles [{}]", name(), token.principal(), user.roles());
listener.onResponse(AuthenticationResult.success(user));
} else {
// We successfully authenticated, but the cached user is disabled.
// Reload the primary record to check whether the user is still disabled
cache.invalidate(token.principal());
doAuthenticateAndCache(token, ActionListener.wrap((result) -> {
if (result.isAuthenticated()) {
final User user = result.getUser();
logger.debug("realm [{}] authenticated user [{}] (enabled:{}), with roles [{}]", name(), token.principal(),
user.enabled(), user.roles());
}
listener.onResponse(result);
}, listener::onFailure));
}
} else {
cache.invalidate(token.principal());
doAuthenticateAndCache(token, ActionListener.wrap((result) -> {
if (result.isAuthenticated()) { if (result.isAuthenticated()) {
final User user = result.getUser(); final User user = result.getUser();
logger.debug("cached user's password changed. realm [{}] authenticated user [{}], with roles [{}]", authenticatedUser.set(user);
name(), token.principal(), user.roles()); final UserWithHash userWithHash = new UserWithHash(user, token.credentials(), hasher);
future.onResponse(new Tuple<>(result, userWithHash));
} else {
future.onResponse(new Tuple<>(result, null));
} }
listener.onResponse(result); }, future::onFailure));
}, listener::onFailure));
} }
} else {
cache.invalidate(token.principal()); future.addListener(ActionListener.wrap(tuple -> {
doAuthenticateAndCache(token, ActionListener.wrap((result) -> { if (tuple != null) {
if (result.isAuthenticated()) { final UserWithHash userWithHash = tuple.v2();
final User user = result.getUser(); final boolean performedAuthentication = createdAndStartedFuture.get() && userWithHash != null &&
logger.debug("cached user came from a lookup and could not be used for authentication. " + tuple.v2().user == authenticatedUser.get();
"realm [{}] authenticated user [{}] with roles [{}]", name(), token.principal(), user.roles()); handleResult(future, createdAndStartedFuture.get(), performedAuthentication, token, tuple, listener);
} else {
handleFailure(future, createdAndStartedFuture.get(), token, new IllegalStateException("unknown error authenticating"),
listener);
} }
listener.onResponse(result); }, e -> handleFailure(future, createdAndStartedFuture.get(), token, e, listener)),
}, listener::onFailure)); threadPool.executor(ThreadPool.Names.GENERIC));
} catch (ExecutionException e) {
listener.onResponse(AuthenticationResult.unsuccessful("", e));
} }
} }
private void doAuthenticateAndCache(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { private void handleResult(ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future, boolean createdAndStartedFuture,
ActionListener<AuthenticationResult> wrapped = ActionListener.wrap((result) -> { boolean performedAuthentication, UsernamePasswordToken token,
Objects.requireNonNull(result, "AuthenticationResult cannot be null"); Tuple<AuthenticationResult, UserWithHash> result, ActionListener<AuthenticationResult> listener) {
if (result.getStatus() == AuthenticationResult.Status.SUCCESS) { final AuthenticationResult authResult = result.v1();
UserWithHash userWithHash = new UserWithHash(result.getUser(), token.credentials(), hasher); if (authResult == null) {
// it doesn't matter if we already computed it elsewhere // this was from a lookup; clear and redo
cache.put(token.principal(), userWithHash); cache.invalidate(token.principal(), future);
authenticateWithCache(token, listener);
} else if (authResult.isAuthenticated()) {
if (performedAuthentication) {
listener.onResponse(authResult);
} else {
UserWithHash userWithHash = result.v2();
if (userWithHash.verify(token.credentials())) {
if (userWithHash.user.enabled()) {
User user = userWithHash.user;
logger.debug("realm [{}] authenticated user [{}], with roles [{}]",
name(), token.principal(), user.roles());
listener.onResponse(AuthenticationResult.success(user));
} else {
// re-auth to see if user has been enabled
cache.invalidate(token.principal(), future);
authenticateWithCache(token, listener);
}
} else {
// could be a password change?
cache.invalidate(token.principal(), future);
authenticateWithCache(token, listener);
}
} }
listener.onResponse(result); } else {
}, listener::onFailure); cache.invalidate(token.principal(), future);
if (createdAndStartedFuture) {
listener.onResponse(authResult);
} else {
authenticateWithCache(token, listener);
}
}
}
doAuthenticate(token, wrapped); private void handleFailure(ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future, boolean createdAndStarted,
UsernamePasswordToken token, Exception e, ActionListener<AuthenticationResult> listener) {
cache.invalidate(token.principal(), future);
if (createdAndStarted) {
listener.onFailure(e);
} else {
authenticateWithCache(token, listener);
}
} }
@Override @Override
@ -160,29 +188,34 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
@Override @Override
public final void lookupUser(String username, ActionListener<User> listener) { public final void lookupUser(String username, ActionListener<User> listener) {
if (cache != null) { if (cache != null) {
UserWithHash withHash = cache.get(username); try {
if (withHash == null) { ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future = cache.computeIfAbsent(username, key -> {
try { ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> created = new ListenableFuture<>();
doLookupUser(username, ActionListener.wrap((user) -> { doLookupUser(username, ActionListener.wrap(user -> {
Runnable action = () -> listener.onResponse(null);
if (user != null) { if (user != null) {
UserWithHash userWithHash = new UserWithHash(user, null, null); UserWithHash userWithHash = new UserWithHash(user, null, null);
try { created.onResponse(new Tuple<>(null, userWithHash));
// computeIfAbsent is used here to avoid overwriting a value from a concurrent authenticate call as it } else {
// contains the password hash, which provides a performance boost and we shouldn't just erase that created.onResponse(new Tuple<>(null, null));
cache.computeIfAbsent(username, (n) -> userWithHash);
action = () -> listener.onResponse(userWithHash.user);
} catch (ExecutionException e) {
action = () -> listener.onFailure(e);
}
} }
action.run(); }, created::onFailure));
}, listener::onFailure)); return created;
} catch (Exception e) { });
listener.onFailure(e);
} future.addListener(ActionListener.wrap(tuple -> {
} else { if (tuple != null) {
listener.onResponse(withHash.user); if (tuple.v2() == null) {
cache.invalidate(username, future);
listener.onResponse(null);
} else {
listener.onResponse(tuple.v2().user);
}
} else {
listener.onResponse(null);
}
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
} catch (ExecutionException e) {
listener.onFailure(e);
} }
} else { } else {
doLookupUser(username, listener); doLookupUser(username, listener);
@ -192,12 +225,12 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
protected abstract void doLookupUser(String username, ActionListener<User> listener); protected abstract void doLookupUser(String username, ActionListener<User> listener);
private static class UserWithHash { private static class UserWithHash {
User user; final User user;
char[] hash; final char[] hash;
Hasher hasher; final Hasher hasher;
UserWithHash(User user, SecureString password, Hasher hasher) { UserWithHash(User user, SecureString password, Hasher hasher) {
this.user = user; this.user = Objects.requireNonNull(user);
this.hash = password == null ? null : hasher.hash(password); this.hash = password == null ? null : hasher.hash(password);
this.hasher = hasher; this.hasher = hasher;
} }
@ -205,9 +238,5 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
boolean verify(SecureString password) { boolean verify(SecureString password) {
return hash != null && hasher.verify(password, hash); return hash != null && hasher.verify(password, hash);
} }
boolean hasHash() {
return hash != null;
}
} }
} }

View File

@ -13,9 +13,9 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment; import org.elasticsearch.env.Environment;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.user.GetUsersRequest; import org.elasticsearch.xpack.core.security.action.user.GetUsersRequest;
@ -28,6 +28,7 @@ import org.elasticsearch.xpack.security.authc.esnative.NativeUsersStore;
import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm; import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
import org.elasticsearch.xpack.security.authc.esnative.ReservedRealmTests; import org.elasticsearch.xpack.security.authc.esnative.ReservedRealmTests;
import org.elasticsearch.xpack.security.support.SecurityIndexManager; import org.elasticsearch.xpack.security.support.SecurityIndexManager;
import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -62,6 +63,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
private boolean anonymousEnabled; private boolean anonymousEnabled;
private Settings settings; private Settings settings;
private ThreadPool threadPool;
@Before @Before
public void maybeEnableAnonymous() { public void maybeEnableAnonymous() {
@ -71,6 +73,14 @@ public class TransportGetUsersActionTests extends ESTestCase {
} else { } else {
settings = Settings.EMPTY; settings = Settings.EMPTY;
} }
threadPool = new TestThreadPool("TransportGetUsersActionTests");
}
@After
public void terminateThreadPool() throws InterruptedException {
if (threadPool != null) {
terminate(threadPool);
}
} }
public void testAnonymousUser() { public void testAnonymousUser() {
@ -79,10 +89,10 @@ public class TransportGetUsersActionTests extends ESTestCase {
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
AnonymousUser anonymousUser = new AnonymousUser(settings); AnonymousUser anonymousUser = new AnonymousUser(settings);
ReservedRealm reservedRealm = ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), settings, usersStore, anonymousUser, securityIndex, new ThreadContext(Settings.EMPTY)); new ReservedRealm(mock(Environment.class), settings, usersStore, anonymousUser, securityIndex, threadPool);
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm); mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm);
GetUsersRequest request = new GetUsersRequest(); GetUsersRequest request = new GetUsersRequest();
@ -117,7 +127,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
NativeUsersStore usersStore = mock(NativeUsersStore.class); NativeUsersStore usersStore = mock(NativeUsersStore.class);
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class)); mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class));
GetUsersRequest request = new GetUsersRequest(); GetUsersRequest request = new GetUsersRequest();
@ -151,7 +161,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap()); ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap());
ReservedRealm reservedRealm = ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings), securityIndex, new ThreadContext(Settings.EMPTY)); new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings), securityIndex, threadPool);
PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>(); PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>();
reservedRealm.users(userFuture); reservedRealm.users(userFuture);
final Collection<User> allReservedUsers = userFuture.actionGet(); final Collection<User> allReservedUsers = userFuture.actionGet();
@ -160,7 +170,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
final List<String> names = reservedUsers.stream().map(User::principal).collect(Collectors.toList()); final List<String> names = reservedUsers.stream().map(User::principal).collect(Collectors.toList());
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm); mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm);
logger.error("names {}", names); logger.error("names {}", names);
@ -197,10 +207,10 @@ public class TransportGetUsersActionTests extends ESTestCase {
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap()); ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap());
ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings), ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings),
securityIndex, new ThreadContext(Settings.EMPTY)); securityIndex, threadPool);
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm); mock(IndexNameExpressionResolver.class), usersStore, transportService, reservedRealm);
GetUsersRequest request = new GetUsersRequest(); GetUsersRequest request = new GetUsersRequest();
@ -247,7 +257,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
NativeUsersStore usersStore = mock(NativeUsersStore.class); NativeUsersStore usersStore = mock(NativeUsersStore.class);
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class)); mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class));
GetUsersRequest request = new GetUsersRequest(); GetUsersRequest request = new GetUsersRequest();
@ -295,7 +305,7 @@ public class TransportGetUsersActionTests extends ESTestCase {
NativeUsersStore usersStore = mock(NativeUsersStore.class); NativeUsersStore usersStore = mock(NativeUsersStore.class);
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportGetUsersAction action = new TransportGetUsersAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class)); mock(IndexNameExpressionResolver.class), usersStore, transportService, mock(ReservedRealm.class));
GetUsersRequest request = new GetUsersRequest(); GetUsersRequest request = new GetUsersRequest();

View File

@ -121,14 +121,16 @@ public class TransportPutUserActionTests extends ESTestCase {
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap()); ReservedRealmTests.mockGetAllReservedUserInfo(usersStore, Collections.emptyMap());
Settings settings = Settings.builder().put("path.home", createTempDir()).build(); Settings settings = Settings.builder().put("path.home", createTempDir()).build();
final ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(settings));
ReservedRealm reservedRealm = new ReservedRealm(TestEnvironment.newEnvironment(settings), settings, usersStore, ReservedRealm reservedRealm = new ReservedRealm(TestEnvironment.newEnvironment(settings), settings, usersStore,
new AnonymousUser(settings), securityIndex, new ThreadContext(settings)); new AnonymousUser(settings), securityIndex, threadPool);
PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>(); PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>();
reservedRealm.users(userFuture); reservedRealm.users(userFuture);
final User reserved = randomFrom(userFuture.actionGet().toArray(new User[0])); final User reserved = randomFrom(userFuture.actionGet().toArray(new User[0]));
TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, TransportService transportService = new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null, null, Collections.emptySet()); x -> null, null, Collections.emptySet());
TransportPutUserAction action = new TransportPutUserAction(Settings.EMPTY, mock(ThreadPool.class), mock(ActionFilters.class), TransportPutUserAction action = new TransportPutUserAction(Settings.EMPTY, threadPool, mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class), usersStore, transportService); mock(IndexNameExpressionResolver.class), usersStore, transportService);
PutUserRequest request = new PutUserRequest(); PutUserRequest request = new PutUserRequest();

View File

@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
import org.elasticsearch.xpack.security.support.SecurityIndexManager; import org.elasticsearch.xpack.security.support.SecurityIndexManager;
@ -18,6 +19,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.xpack.security.test.SecurityTestUtils.getClusterIndexHealth; import static org.elasticsearch.xpack.security.test.SecurityTestUtils.getClusterIndexHealth;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class NativeRealmTests extends ESTestCase { public class NativeRealmTests extends ESTestCase {
@ -26,12 +28,15 @@ public class NativeRealmTests extends ESTestCase {
} }
public void testCacheClearOnIndexHealthChange() { public void testCacheClearOnIndexHealthChange() {
final ThreadPool threadPool = mock(ThreadPool.class);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext);
final AtomicInteger numInvalidation = new AtomicInteger(0); final AtomicInteger numInvalidation = new AtomicInteger(0);
int expectedInvalidation = 0; int expectedInvalidation = 0;
Settings settings = Settings.builder().put("path.home", createTempDir()).build(); Settings settings = Settings.builder().put("path.home", createTempDir()).build();
RealmConfig config = new RealmConfig("native", Settings.EMPTY, settings, TestEnvironment.newEnvironment(settings), RealmConfig config = new RealmConfig("native", Settings.EMPTY, settings, TestEnvironment.newEnvironment(settings),
new ThreadContext(settings)); new ThreadContext(settings));
final NativeRealm nativeRealm = new NativeRealm(config, mock(NativeUsersStore.class)) { final NativeRealm nativeRealm = new NativeRealm(config, mock(NativeUsersStore.class), threadPool) {
@Override @Override
void clearCache() { void clearCache() {
numInvalidation.incrementAndGet(); numInvalidation.incrementAndGet();

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment; import org.elasticsearch.env.Environment;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.esnative.ClientReservedRealm; import org.elasticsearch.xpack.core.security.authc.esnative.ClientReservedRealm;
@ -63,6 +64,7 @@ public class ReservedRealmTests extends ESTestCase {
private static final SecureString EMPTY_PASSWORD = new SecureString("".toCharArray()); private static final SecureString EMPTY_PASSWORD = new SecureString("".toCharArray());
private NativeUsersStore usersStore; private NativeUsersStore usersStore;
private SecurityIndexManager securityIndex; private SecurityIndexManager securityIndex;
private ThreadPool threadPool;
@Before @Before
public void setupMocks() throws Exception { public void setupMocks() throws Exception {
@ -71,6 +73,8 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
when(securityIndex.checkMappingVersion(any())).thenReturn(true); when(securityIndex.checkMappingVersion(any())).thenReturn(true);
mockGetAllReservedUserInfo(usersStore, Collections.emptyMap()); mockGetAllReservedUserInfo(usersStore, Collections.emptyMap());
threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
} }
public void testReservedUserEmptyPasswordAuthenticationFails() throws Throwable { public void testReservedUserEmptyPasswordAuthenticationFails() throws Throwable {
@ -78,7 +82,7 @@ public class ReservedRealmTests extends ESTestCase {
UsernamesField.BEATS_NAME); UsernamesField.BEATS_NAME);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
@ -94,7 +98,7 @@ public class ReservedRealmTests extends ESTestCase {
} }
final ReservedRealm reservedRealm = final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), settings, usersStore, new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(settings), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(settings), securityIndex, threadPool);
final User expected = randomReservedUser(true); final User expected = randomReservedUser(true);
final String principal = expected.principal(); final String principal = expected.principal();
@ -116,7 +120,7 @@ public class ReservedRealmTests extends ESTestCase {
private void verifySuccessfulAuthentication(boolean enabled) throws Exception { private void verifySuccessfulAuthentication(boolean enabled) throws Exception {
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
final User expectedUser = randomReservedUser(enabled); final User expectedUser = randomReservedUser(enabled);
final String principal = expectedUser.principal(); final String principal = expectedUser.principal();
final SecureString newPassword = new SecureString("foobar".toCharArray()); final SecureString newPassword = new SecureString("foobar".toCharArray());
@ -157,7 +161,7 @@ public class ReservedRealmTests extends ESTestCase {
public void testLookup() throws Exception { public void testLookup() throws Exception {
final ReservedRealm reservedRealm = final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
final User expectedUser = randomReservedUser(true); final User expectedUser = randomReservedUser(true);
final String principal = expectedUser.principal(); final String principal = expectedUser.principal();
@ -182,7 +186,7 @@ public class ReservedRealmTests extends ESTestCase {
Settings settings = Settings.builder().put(XPackSettings.RESERVED_REALM_ENABLED_SETTING.getKey(), false).build(); Settings settings = Settings.builder().put(XPackSettings.RESERVED_REALM_ENABLED_SETTING.getKey(), false).build();
final ReservedRealm reservedRealm = final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings), new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings),
securityIndex, new ThreadContext(Settings.EMPTY)); securityIndex, threadPool);
final User expectedUser = randomReservedUser(true); final User expectedUser = randomReservedUser(true);
final String principal = expectedUser.principal(); final String principal = expectedUser.principal();
@ -196,7 +200,7 @@ public class ReservedRealmTests extends ESTestCase {
public void testLookupThrows() throws Exception { public void testLookupThrows() throws Exception {
final ReservedRealm reservedRealm = final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
final User expectedUser = randomReservedUser(true); final User expectedUser = randomReservedUser(true);
final String principal = expectedUser.principal(); final String principal = expectedUser.principal();
when(securityIndex.indexExists()).thenReturn(true); when(securityIndex.indexExists()).thenReturn(true);
@ -243,7 +247,7 @@ public class ReservedRealmTests extends ESTestCase {
public void testGetUsers() { public void testGetUsers() {
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>(); PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>();
reservedRealm.users(userFuture); reservedRealm.users(userFuture);
assertThat(userFuture.actionGet(), assertThat(userFuture.actionGet(),
@ -258,7 +262,7 @@ public class ReservedRealmTests extends ESTestCase {
.build(); .build();
final AnonymousUser anonymousUser = new AnonymousUser(settings); final AnonymousUser anonymousUser = new AnonymousUser(settings);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, anonymousUser, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, anonymousUser,
securityIndex, new ThreadContext(Settings.EMPTY)); securityIndex, threadPool);
PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>(); PlainActionFuture<Collection<User>> userFuture = new PlainActionFuture<>();
reservedRealm.users(userFuture); reservedRealm.users(userFuture);
if (anonymousEnabled) { if (anonymousEnabled) {
@ -275,7 +279,7 @@ public class ReservedRealmTests extends ESTestCase {
ReservedUserInfo userInfo = new ReservedUserInfo(hash, true, false); ReservedUserInfo userInfo = new ReservedUserInfo(hash, true, false);
mockGetAllReservedUserInfo(usersStore, Collections.singletonMap("elastic", userInfo)); mockGetAllReservedUserInfo(usersStore, Collections.singletonMap("elastic", userInfo));
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
if (randomBoolean()) { if (randomBoolean()) {
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
@ -305,7 +309,7 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.indexExists()).thenReturn(true); when(securityIndex.indexExists()).thenReturn(true);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
doAnswer((i) -> { doAnswer((i) -> {
@ -327,7 +331,7 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.indexExists()).thenReturn(true); when(securityIndex.indexExists()).thenReturn(true);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
SecureString password = new SecureString("password".toCharArray()); SecureString password = new SecureString("password".toCharArray());
doAnswer((i) -> { doAnswer((i) -> {
@ -354,7 +358,7 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.indexExists()).thenReturn(false); when(securityIndex.indexExists()).thenReturn(false);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
reservedRealm.doAuthenticate(new UsernamePasswordToken(new ElasticUser(true).principal(), reservedRealm.doAuthenticate(new UsernamePasswordToken(new ElasticUser(true).principal(),
@ -372,7 +376,7 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.indexExists()).thenReturn(true); when(securityIndex.indexExists()).thenReturn(true);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
final String principal = randomFrom(KibanaUser.NAME, LogstashSystemUser.NAME, BeatsSystemUser.NAME); final String principal = randomFrom(KibanaUser.NAME, LogstashSystemUser.NAME, BeatsSystemUser.NAME);
@ -394,7 +398,7 @@ public class ReservedRealmTests extends ESTestCase {
when(securityIndex.indexExists()).thenReturn(false); when(securityIndex.indexExists()).thenReturn(false);
final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore, final ReservedRealm reservedRealm = new ReservedRealm(mock(Environment.class), settings, usersStore,
new AnonymousUser(Settings.EMPTY), securityIndex, new ThreadContext(Settings.EMPTY)); new AnonymousUser(Settings.EMPTY), securityIndex, threadPool);
PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> listener = new PlainActionFuture<>();
final String principal = randomFrom(KibanaUser.NAME, LogstashSystemUser.NAME, BeatsSystemUser.NAME); final String principal = randomFrom(KibanaUser.NAME, LogstashSystemUser.NAME, BeatsSystemUser.NAME);

View File

@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@ -50,20 +51,26 @@ public class FileRealmTests extends ESTestCase {
private FileUserPasswdStore userPasswdStore; private FileUserPasswdStore userPasswdStore;
private FileUserRolesStore userRolesStore; private FileUserRolesStore userRolesStore;
private Settings globalSettings; private Settings globalSettings;
private ThreadPool threadPool;
private ThreadContext threadContext;
@Before @Before
public void init() throws Exception { public void init() throws Exception {
userPasswdStore = mock(FileUserPasswdStore.class); userPasswdStore = mock(FileUserPasswdStore.class);
userRolesStore = mock(FileUserRolesStore.class); userRolesStore = mock(FileUserRolesStore.class);
globalSettings = Settings.builder().put("path.home", createTempDir()).build(); globalSettings = Settings.builder().put("path.home", createTempDir()).build();
threadPool = mock(ThreadPool.class);
threadContext = new ThreadContext(globalSettings);
when(threadPool.getThreadContext()).thenReturn(threadContext);
} }
public void testAuthenticate() throws Exception { public void testAuthenticate() throws Exception {
when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class))) when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class)))
.thenAnswer(VERIFY_PASSWORD_ANSWER); .thenAnswer(VERIFY_PASSWORD_ANSWER);
when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" }); when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" });
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); threadContext);
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future); realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future);
final AuthenticationResult result = future.actionGet(); final AuthenticationResult result = future.actionGet();
@ -80,11 +87,12 @@ public class FileRealmTests extends ESTestCase {
Settings settings = Settings.builder() Settings settings = Settings.builder()
.put("cache.hash_algo", Hasher.values()[randomIntBetween(0, Hasher.values().length - 1)].name().toLowerCase(Locale.ROOT)) .put("cache.hash_algo", Hasher.values()[randomIntBetween(0, Hasher.values().length - 1)].name().toLowerCase(Locale.ROOT))
.build(); .build();
RealmConfig config = new RealmConfig("file-test", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings),
threadContext);
when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class))) when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class)))
.thenAnswer(VERIFY_PASSWORD_ANSWER); .thenAnswer(VERIFY_PASSWORD_ANSWER);
when(userRolesStore.roles("user1")).thenReturn(new String[]{"role1", "role2"}); when(userRolesStore.roles("user1")).thenReturn(new String[]{"role1", "role2"});
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future); realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future);
User user1 = future.actionGet().getUser(); User user1 = future.actionGet().getUser();
@ -95,13 +103,14 @@ public class FileRealmTests extends ESTestCase {
} }
public void testAuthenticateCachingRefresh() throws Exception { public void testAuthenticateCachingRefresh() throws Exception {
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
threadContext);
userPasswdStore = spy(new UserPasswdStore(config)); userPasswdStore = spy(new UserPasswdStore(config));
userRolesStore = spy(new UserRolesStore(config)); userRolesStore = spy(new UserRolesStore(config));
when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class))) when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class)))
.thenAnswer(VERIFY_PASSWORD_ANSWER); .thenAnswer(VERIFY_PASSWORD_ANSWER);
doReturn(new String[] { "role1", "role2" }).when(userRolesStore).roles("user1"); doReturn(new String[] { "role1", "role2" }).when(userRolesStore).roles("user1");
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future); realm.authenticate(new UsernamePasswordToken("user1", new SecureString("test123")), future);
User user1 = future.actionGet().getUser(); User user1 = future.actionGet().getUser();
@ -134,11 +143,12 @@ public class FileRealmTests extends ESTestCase {
} }
public void testToken() throws Exception { public void testToken() throws Exception {
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
threadContext);
when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class))) when(userPasswdStore.verifyPassword(eq("user1"), eq(new SecureString("test123")), any(Supplier.class)))
.thenAnswer(VERIFY_PASSWORD_ANSWER); .thenAnswer(VERIFY_PASSWORD_ANSWER);
when(userRolesStore.roles("user1")).thenReturn(new String[]{"role1", "role2"}); when(userRolesStore.roles("user1")).thenReturn(new String[]{"role1", "role2"});
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY); ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
UsernamePasswordToken.putTokenHeader(threadContext, new UsernamePasswordToken("user1", new SecureString("test123"))); UsernamePasswordToken.putTokenHeader(threadContext, new UsernamePasswordToken("user1", new SecureString("test123")));
@ -153,8 +163,9 @@ public class FileRealmTests extends ESTestCase {
public void testLookup() throws Exception { public void testLookup() throws Exception {
when(userPasswdStore.userExists("user1")).thenReturn(true); when(userPasswdStore.userExists("user1")).thenReturn(true);
when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" }); when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" });
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); threadContext);
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("user1", future); realm.lookupUser("user1", future);
@ -170,8 +181,9 @@ public class FileRealmTests extends ESTestCase {
public void testLookupCaching() throws Exception { public void testLookupCaching() throws Exception {
when(userPasswdStore.userExists("user1")).thenReturn(true); when(userPasswdStore.userExists("user1")).thenReturn(true);
when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" }); when(userRolesStore.roles("user1")).thenReturn(new String[] { "role1", "role2" });
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); threadContext);
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("user1", future); realm.lookupUser("user1", future);
@ -185,12 +197,13 @@ public class FileRealmTests extends ESTestCase {
} }
public void testLookupCachingWithRefresh() throws Exception { public void testLookupCachingWithRefresh() throws Exception {
RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
threadContext);
userPasswdStore = spy(new UserPasswdStore(config)); userPasswdStore = spy(new UserPasswdStore(config));
userRolesStore = spy(new UserRolesStore(config)); userRolesStore = spy(new UserRolesStore(config));
doReturn(true).when(userPasswdStore).userExists("user1"); doReturn(true).when(userPasswdStore).userExists("user1");
doReturn(new String[] { "role1", "role2" }).when(userRolesStore).roles("user1"); doReturn(new String[] { "role1", "role2" }).when(userRolesStore).roles("user1");
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("user1", future); realm.lookupUser("user1", future);
User user1 = future.actionGet(); User user1 = future.actionGet();
@ -231,8 +244,9 @@ public class FileRealmTests extends ESTestCase {
int order = randomIntBetween(0, 10); int order = randomIntBetween(0, 10);
settings.put("order", order); settings.put("order", order);
RealmConfig config = new RealmConfig("file-realm", settings.build(), globalSettings, TestEnvironment.newEnvironment(globalSettings), new ThreadContext(globalSettings)); RealmConfig config = new RealmConfig("file-realm", settings.build(), globalSettings, TestEnvironment.newEnvironment(globalSettings),
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore); threadContext);
FileRealm realm = new FileRealm(config, userPasswdStore, userRolesStore, threadPool);
Map<String, Object> usage = realm.usageStats(); Map<String, Object> usage = realm.usageStats();
assertThat(usage, is(notNullValue())); assertThat(usage, is(notNullValue()));

View File

@ -14,6 +14,8 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.Realm;
import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@ -22,6 +24,7 @@ import org.elasticsearch.xpack.core.security.authc.support.CachingUsernamePasswo
import org.elasticsearch.xpack.core.security.authc.support.Hasher; import org.elasticsearch.xpack.core.security.authc.support.Hasher;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.core.security.user.User;
import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.util.ArrayList; import java.util.ArrayList;
@ -42,10 +45,19 @@ import static org.hamcrest.Matchers.sameInstance;
public class CachingUsernamePasswordRealmTests extends ESTestCase { public class CachingUsernamePasswordRealmTests extends ESTestCase {
private Settings globalSettings; private Settings globalSettings;
private ThreadPool threadPool;
@Before @Before
public void setup() { public void setup() {
globalSettings = Settings.builder().put("path.home", createTempDir()).build(); globalSettings = Settings.builder().put("path.home", createTempDir()).build();
threadPool = new TestThreadPool("caching username password realm tests");
}
@After
public void stop() throws InterruptedException {
if (threadPool != null) {
terminate(threadPool);
}
} }
public void testSettings() throws Exception { public void testSettings() throws Exception {
@ -61,7 +73,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
RealmConfig config = new RealmConfig("test_realm", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings), RealmConfig config = new RealmConfig("test_realm", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY)); new ThreadContext(Settings.EMPTY));
CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config) { CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config, threadPool) {
@Override @Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
listener.onResponse(AuthenticationResult.success(new User("username", new String[]{"r1", "r2", "r3"}))); listener.onResponse(AuthenticationResult.success(new User("username", new String[]{"r1", "r2", "r3"})));
@ -77,7 +89,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testAuthCache() { public void testAuthCache() {
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool);
SecureString pass = new SecureString("pass"); SecureString pass = new SecureString("pass");
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("a", pass), future); realm.authenticate(new UsernamePasswordToken("a", pass), future);
@ -106,7 +118,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testLookupCache() { public void testLookupCache() {
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("a", future); realm.lookupUser("a", future);
future.actionGet(); future.actionGet();
@ -133,7 +145,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testLookupAndAuthCache() { public void testLookupAndAuthCache() {
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool);
// lookup first // lookup first
PlainActionFuture<User> lookupFuture = new PlainActionFuture<>(); PlainActionFuture<User> lookupFuture = new PlainActionFuture<>();
realm.lookupUser("a", lookupFuture); realm.lookupUser("a", lookupFuture);
@ -172,7 +184,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testCacheChangePassword() { public void testCacheChangePassword() {
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool);
String user = "testUser"; String user = "testUser";
SecureString pass1 = new SecureString("pass"); SecureString pass1 = new SecureString("pass");
@ -198,7 +210,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testCacheDisabledUser() { public void testCacheDisabledUser() {
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(globalSettings, threadPool);
realm.setUsersEnabled(false); realm.setUsersEnabled(false);
String user = "testUser"; String user = "testUser";
@ -233,7 +245,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
.build(); .build();
RealmConfig config = new RealmConfig("test_cache_ttl", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings), RealmConfig config = new RealmConfig("test_cache_ttl", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY)); new ThreadContext(Settings.EMPTY));
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(config); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(config, threadPool);
final UsernamePasswordToken authToken = new UsernamePasswordToken("the-user", new SecureString("the-password")); final UsernamePasswordToken authToken = new UsernamePasswordToken("the-user", new SecureString("the-password"));
@ -262,7 +274,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
.build(); .build();
RealmConfig config = new RealmConfig("test_cache_ttl", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings), RealmConfig config = new RealmConfig("test_cache_ttl", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY)); new ThreadContext(Settings.EMPTY));
AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(config); AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(config, threadPool);
final UsernamePasswordToken authToken = new UsernamePasswordToken("the-user", new SecureString("the-password")); final UsernamePasswordToken authToken = new UsernamePasswordToken("the-user", new SecureString("the-password"));
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
@ -304,13 +316,13 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testAuthenticateContract() throws Exception { public void testAuthenticateContract() throws Exception {
Realm realm = new FailingAuthenticationRealm(Settings.EMPTY, globalSettings); Realm realm = new FailingAuthenticationRealm(Settings.EMPTY, globalSettings, threadPool);
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>(); PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future); realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future);
User user = future.actionGet().getUser(); User user = future.actionGet().getUser();
assertThat(user, nullValue()); assertThat(user, nullValue());
realm = new ThrowingAuthenticationRealm(Settings.EMPTY, globalSettings); realm = new ThrowingAuthenticationRealm(Settings.EMPTY, globalSettings, threadPool);
future = new PlainActionFuture<>(); future = new PlainActionFuture<>();
realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future); realm.authenticate(new UsernamePasswordToken("user", new SecureString("pass")), future);
RuntimeException e = expectThrows(RuntimeException.class, future::actionGet); RuntimeException e = expectThrows(RuntimeException.class, future::actionGet);
@ -318,19 +330,85 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
public void testLookupContract() throws Exception { public void testLookupContract() throws Exception {
Realm realm = new FailingAuthenticationRealm(Settings.EMPTY, globalSettings); Realm realm = new FailingAuthenticationRealm(Settings.EMPTY, globalSettings, threadPool);
PlainActionFuture<User> future = new PlainActionFuture<>(); PlainActionFuture<User> future = new PlainActionFuture<>();
realm.lookupUser("user", future); realm.lookupUser("user", future);
User user = future.actionGet(); User user = future.actionGet();
assertThat(user, nullValue()); assertThat(user, nullValue());
realm = new ThrowingAuthenticationRealm(Settings.EMPTY, globalSettings); realm = new ThrowingAuthenticationRealm(Settings.EMPTY, globalSettings, threadPool);
future = new PlainActionFuture<>(); future = new PlainActionFuture<>();
realm.lookupUser("user", future); realm.lookupUser("user", future);
RuntimeException e = expectThrows(RuntimeException.class, future::actionGet); RuntimeException e = expectThrows(RuntimeException.class, future::actionGet);
assertThat(e.getMessage(), containsString("lookup exception")); assertThat(e.getMessage(), containsString("lookup exception"));
} }
public void testSingleAuthPerUserLimit() throws Exception {
final String username = "username";
final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING;
final AtomicInteger authCounter = new AtomicInteger(0);
final String passwordHash = new String(Hasher.BCRYPT.hash(password));
RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY));
final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config, threadPool) {
@Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
authCounter.incrementAndGet();
// do something slow
if (BCrypt.checkpw(token.credentials(), passwordHash)) {
listener.onResponse(AuthenticationResult.success(new User(username, new String[]{"r1", "r2", "r3"})));
} else {
listener.onFailure(new IllegalStateException("password auth should never fail"));
}
}
@Override
protected void doLookupUser(String username, ActionListener<User> listener) {
listener.onFailure(new UnsupportedOperationException("this method should not be called"));
}
};
final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
final int numberOfIterations = scaledRandomIntBetween(20, 100);
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
threads.add(new Thread(() -> {
try {
latch.countDown();
latch.await();
for (int i1 = 0; i1 < numberOfIterations; i1++) {
UsernamePasswordToken token = new UsernamePasswordToken(username, password);
realm.authenticate(token, ActionListener.wrap((result) -> {
if (result.isAuthenticated() == false) {
throw new IllegalStateException("proper password led to an unauthenticated result: " + result);
}
}, (e) -> {
logger.error("caught exception", e);
fail("unexpected exception - " + e);
}));
}
} catch (InterruptedException e) {
logger.error("thread was interrupted", e);
Thread.currentThread().interrupt();
}
}));
}
for (Thread thread : threads) {
thread.start();
}
latch.countDown();
for (Thread thread : threads) {
thread.join();
}
assertEquals(1, authCounter.get());
}
public void testCacheConcurrency() throws Exception { public void testCacheConcurrency() throws Exception {
final String username = "username"; final String username = "username";
final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING; final SecureString password = SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING;
@ -339,7 +417,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
final String passwordHash = new String(Hasher.BCRYPT.hash(password)); final String passwordHash = new String(Hasher.BCRYPT.hash(password));
RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY)); new ThreadContext(Settings.EMPTY));
final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config) { final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config, threadPool) {
@Override @Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
// do something slow // do something slow
@ -356,37 +434,37 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
} }
}; };
final CountDownLatch latch = new CountDownLatch(1);
final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
final int numberOfIterations = scaledRandomIntBetween(20, 100); final int numberOfIterations = scaledRandomIntBetween(20, 100);
List<Thread> threads = new ArrayList<>(); final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) { for (int i = 0; i < numberOfThreads; i++) {
final boolean invalidPassword = randomBoolean(); final boolean invalidPassword = randomBoolean();
threads.add(new Thread() { threads.add(new Thread(() -> {
@Override try {
public void run() { latch.countDown();
try { latch.await();
latch.await(); for (int i1 = 0; i1 < numberOfIterations; i1++) {
for (int i = 0; i < numberOfIterations; i++) { UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
realm.authenticate(token, ActionListener.wrap((result) -> { realm.authenticate(token, ActionListener.wrap((result) -> {
if (invalidPassword && result.isAuthenticated()) { if (invalidPassword && result.isAuthenticated()) {
throw new RuntimeException("invalid password led to an authenticated user: " + result); throw new RuntimeException("invalid password led to an authenticated user: " + result);
} else if (invalidPassword == false && result.isAuthenticated() == false) { } else if (invalidPassword == false && result.isAuthenticated() == false) {
throw new RuntimeException("proper password led to an unauthenticated result: " + result); throw new RuntimeException("proper password led to an unauthenticated result: " + result);
} }
}, (e) -> { }, (e) -> {
logger.error("caught exception", e); logger.error("caught exception", e);
fail("unexpected exception - " + e); fail("unexpected exception - " + e);
})); }));
}
} catch (InterruptedException e) {
} }
} catch (InterruptedException e) {
logger.error("thread was interrupted", e);
Thread.currentThread().interrupt();
} }
}); }));
} }
for (Thread thread : threads) { for (Thread thread : threads) {
@ -400,10 +478,11 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
public void testUserLookupConcurrency() throws Exception { public void testUserLookupConcurrency() throws Exception {
final String username = "username"; final String username = "username";
final AtomicInteger lookupCounter = new AtomicInteger(0);
RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), RealmConfig config = new RealmConfig("test_realm", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY)); new ThreadContext(Settings.EMPTY));
final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config) { final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm("test", config, threadPool) {
@Override @Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
listener.onFailure(new UnsupportedOperationException("authenticate should not be called!")); listener.onFailure(new UnsupportedOperationException("authenticate should not be called!"));
@ -411,36 +490,37 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
@Override @Override
protected void doLookupUser(String username, ActionListener<User> listener) { protected void doLookupUser(String username, ActionListener<User> listener) {
lookupCounter.incrementAndGet();
listener.onResponse(new User(username, new String[]{"r1", "r2", "r3"})); listener.onResponse(new User(username, new String[]{"r1", "r2", "r3"}));
} }
}; };
final CountDownLatch latch = new CountDownLatch(1);
final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween(numberOfProcessors, numberOfProcessors * 3); final int numberOfThreads = scaledRandomIntBetween(numberOfProcessors, numberOfProcessors * 3);
final int numberOfIterations = scaledRandomIntBetween(10000, 100000); final int numberOfIterations = scaledRandomIntBetween(10000, 100000);
List<Thread> threads = new ArrayList<>(); final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) { for (int i = 0; i < numberOfThreads; i++) {
threads.add(new Thread() { threads.add(new Thread(() -> {
@Override try {
public void run() { latch.countDown();
try { latch.await();
latch.await(); for (int i1 = 0; i1 < numberOfIterations; i1++) {
for (int i = 0; i < numberOfIterations; i++) { realm.lookupUser(username, ActionListener.wrap((user) -> {
realm.lookupUser(username, ActionListener.wrap((user) -> { if (user == null) {
if (user == null) { throw new RuntimeException("failed to lookup user");
throw new RuntimeException("failed to lookup user"); }
} }, (e) -> {
}, (e) -> { logger.error("caught exception", e);
logger.error("caught exception", e); fail("unexpected exception");
fail("unexpected exception"); }));
}));
}
} catch (InterruptedException e) {
} }
} catch (InterruptedException e) {
logger.error("thread was interrupted", e);
Thread.currentThread().interrupt();
} }
}); }));
} }
for (Thread thread : threads) { for (Thread thread : threads) {
@ -450,13 +530,14 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
for (Thread thread : threads) { for (Thread thread : threads) {
thread.join(); thread.join();
} }
assertEquals(1, lookupCounter.get());
} }
static class FailingAuthenticationRealm extends CachingUsernamePasswordRealm { static class FailingAuthenticationRealm extends CachingUsernamePasswordRealm {
FailingAuthenticationRealm(Settings settings, Settings global) { FailingAuthenticationRealm(Settings settings, Settings global, ThreadPool threadPool) {
super("failing", new RealmConfig("failing-test", settings, global, TestEnvironment.newEnvironment(global), super("failing", new RealmConfig("failing-test", settings, global, TestEnvironment.newEnvironment(global),
new ThreadContext(Settings.EMPTY))); threadPool.getThreadContext()), threadPool);
} }
@Override @Override
@ -472,9 +553,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
static class ThrowingAuthenticationRealm extends CachingUsernamePasswordRealm { static class ThrowingAuthenticationRealm extends CachingUsernamePasswordRealm {
ThrowingAuthenticationRealm(Settings settings, Settings globalSettings) { ThrowingAuthenticationRealm(Settings settings, Settings globalSettings, ThreadPool threadPool) {
super("throwing", new RealmConfig("throwing-test", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings), super("throwing", new RealmConfig("throwing-test", settings, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY))); threadPool.getThreadContext()), threadPool);
} }
@Override @Override
@ -495,13 +576,13 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
private boolean usersEnabled = true; private boolean usersEnabled = true;
AlwaysAuthenticateCachingRealm(Settings globalSettings) { AlwaysAuthenticateCachingRealm(Settings globalSettings, ThreadPool threadPool) {
this(new RealmConfig("always-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings), this(new RealmConfig("always-test", Settings.EMPTY, globalSettings, TestEnvironment.newEnvironment(globalSettings),
new ThreadContext(Settings.EMPTY))); threadPool.getThreadContext()), threadPool);
} }
AlwaysAuthenticateCachingRealm(RealmConfig config) { AlwaysAuthenticateCachingRealm(RealmConfig config, ThreadPool threadPool) {
super("always", config); super("always", config, threadPool);
} }
void setUsersEnabled(boolean usersEnabled) { void setUsersEnabled(boolean usersEnabled) {
@ -527,9 +608,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
public final AtomicInteger authInvocationCounter = new AtomicInteger(0); public final AtomicInteger authInvocationCounter = new AtomicInteger(0);
public final AtomicInteger lookupInvocationCounter = new AtomicInteger(0); public final AtomicInteger lookupInvocationCounter = new AtomicInteger(0);
LookupNotSupportedRealm(Settings globalSettings) { LookupNotSupportedRealm(Settings globalSettings, ThreadPool threadPool) {
super("lookup", new RealmConfig("lookup-notsupported-test", Settings.EMPTY, globalSettings, super("lookup", new RealmConfig("lookup-notsupported-test", Settings.EMPTY, globalSettings,
TestEnvironment.newEnvironment(globalSettings), new ThreadContext(Settings.EMPTY))); TestEnvironment.newEnvironment(globalSettings), threadPool.getThreadContext()), threadPool);
} }
@Override @Override

View File

@ -198,7 +198,7 @@ public class NativeRoleMappingStoreTests extends ESTestCase {
final Environment env = TestEnvironment.newEnvironment(settings); final Environment env = TestEnvironment.newEnvironment(settings);
final RealmConfig realmConfig = new RealmConfig(getTestName(), Settings.EMPTY, settings, env, threadContext); final RealmConfig realmConfig = new RealmConfig(getTestName(), Settings.EMPTY, settings, env, threadContext);
final CachingUsernamePasswordRealm mockRealm = new CachingUsernamePasswordRealm("test", realmConfig) { final CachingUsernamePasswordRealm mockRealm = new CachingUsernamePasswordRealm("test", realmConfig, threadPool) {
@Override @Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
listener.onResponse(AuthenticationResult.notHandled()); listener.onResponse(AuthenticationResult.notHandled());