diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java index 4ed04864041..2fed720e23c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java @@ -484,13 +484,27 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); - final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); List threads = new ArrayList<>(numberOfThreads); final SecureString credsToUse = new SecureString(randomAlphaOfLength(12).toCharArray()); + + // we use a bunch of different latches here, the first `latch` is used to ensure all threads have been started + // before they start to execute. The `authWaitLatch` is there to ensure we have all threads waiting on the + // listener before we auth otherwise we may run into a race condition where we auth and one of the threads is + // not waiting on auth yet. Finally, the completedLatch is used to signal that each thread received a response! + final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); + final CountDownLatch authWaitLatch = new CountDownLatch(numberOfThreads); + final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads); final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm(config, threadPool) { @Override protected void doAuthenticate(UsernamePasswordToken token, ActionListener listener) { authCounter.incrementAndGet(); + authWaitLatch.countDown(); + try { + authWaitLatch.await(); + } catch (InterruptedException e) { + logger.info("authentication was interrupted", e); + Thread.currentThread().interrupt(); + } // do something slow if (pwdHasher.verify(token.credentials(), passwordHash.toCharArray())) { listener.onFailure(new IllegalStateException("password auth should never succeed")); @@ -513,14 +527,17 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { realm.authenticate(token, ActionListener.wrap((result) -> { if (result.isAuthenticated()) { + completedLatch.countDown(); throw new IllegalStateException("invalid password led to an authenticated result: " + result); } assertThat(result.getMessage(), containsString("password verification failed")); + completedLatch.countDown(); }, (e) -> { logger.error("caught exception", e); + completedLatch.countDown(); fail("unexpected exception - " + e); })); - + authWaitLatch.countDown(); } catch (InterruptedException e) { logger.error("thread was interrupted", e); Thread.currentThread().interrupt(); @@ -535,6 +552,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { for (Thread thread : threads) { thread.join(); } + completedLatch.await(); assertEquals(1, authCounter.get()); }