Test: fix race in auth result propagation test

This commit fixes a race condition in a test introduced by #36900 that
verifies concurrent authentications get a result propagated from the
first thread that attempts to authenticate. Previously, a thread may
be in a state where it had not attempted to authenticate when the first
thread that authenticates finishes the authentication, which would
cause the test to fail as there would be an additional authentication
attempt. This change adds additional latches to ensure all threads have
attempted to authenticate before a result gets returned in the
thread that is performing authentication.
This commit is contained in:
jaymode 2019-01-09 12:17:43 -07:00
parent 722b850efd
commit c71060fa01
No known key found for this signature in database
GPG Key ID: D859847567B3493D
1 changed files with 20 additions and 2 deletions

View File

@ -484,13 +484,27 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads); List<Thread> threads = new ArrayList<>(numberOfThreads);
final SecureString credsToUse = new SecureString(randomAlphaOfLength(12).toCharArray()); final SecureString credsToUse = new SecureString(randomAlphaOfLength(12).toCharArray());
// we use a bunch of different latches here, the first `latch` is used to ensure all threads have been started
// before they start to execute. The `authWaitLatch` is there to ensure we have all threads waiting on the
// listener before we auth otherwise we may run into a race condition where we auth and one of the threads is
// not waiting on auth yet. Finally, the completedLatch is used to signal that each thread received a response!
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
final CountDownLatch authWaitLatch = new CountDownLatch(numberOfThreads);
final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads);
final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm(config, threadPool) { final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm(config, threadPool) {
@Override @Override
protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) { protected void doAuthenticate(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
authCounter.incrementAndGet(); authCounter.incrementAndGet();
authWaitLatch.countDown();
try {
authWaitLatch.await();
} catch (InterruptedException e) {
logger.info("authentication was interrupted", e);
Thread.currentThread().interrupt();
}
// do something slow // do something slow
if (pwdHasher.verify(token.credentials(), passwordHash.toCharArray())) { if (pwdHasher.verify(token.credentials(), passwordHash.toCharArray())) {
listener.onFailure(new IllegalStateException("password auth should never succeed")); listener.onFailure(new IllegalStateException("password auth should never succeed"));
@ -513,14 +527,17 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
realm.authenticate(token, ActionListener.wrap((result) -> { realm.authenticate(token, ActionListener.wrap((result) -> {
if (result.isAuthenticated()) { if (result.isAuthenticated()) {
completedLatch.countDown();
throw new IllegalStateException("invalid password led to an authenticated result: " + result); throw new IllegalStateException("invalid password led to an authenticated result: " + result);
} }
assertThat(result.getMessage(), containsString("password verification failed")); assertThat(result.getMessage(), containsString("password verification failed"));
completedLatch.countDown();
}, (e) -> { }, (e) -> {
logger.error("caught exception", e); logger.error("caught exception", e);
completedLatch.countDown();
fail("unexpected exception - " + e); fail("unexpected exception - " + e);
})); }));
authWaitLatch.countDown();
} catch (InterruptedException e) { } catch (InterruptedException e) {
logger.error("thread was interrupted", e); logger.error("thread was interrupted", e);
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
@ -535,6 +552,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
for (Thread thread : threads) { for (Thread thread : threads) {
thread.join(); thread.join();
} }
completedLatch.await();
assertEquals(1, authCounter.get()); assertEquals(1, authCounter.get());
} }