diff --git a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/authc/ldap/support/SessionFactoryLoadBalancingTests.java b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/authc/ldap/support/SessionFactoryLoadBalancingTests.java index 843d6bd457f..423edb20dc7 100644 --- a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/authc/ldap/support/SessionFactoryLoadBalancingTests.java +++ b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/authc/ldap/support/SessionFactoryLoadBalancingTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.shield.authc.ldap.support; +import com.unboundid.ldap.listener.InMemoryDirectoryServer; import com.unboundid.ldap.sdk.LDAPConnection; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.shield.authc.RealmConfig; @@ -12,8 +13,10 @@ import org.elasticsearch.shield.authc.support.SecuredString; import org.elasticsearch.shield.ssl.ClientSSLService; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -43,6 +46,7 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase { public void testRoundRobinWithFailures() throws Exception { assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.ROUND_ROBIN); // create a list of ports @@ -50,19 +54,31 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase { for (int i = 0; i < ldapServers.length; i++) { ports.add(ldapServers[i].getListenPort()); } + logger.debug("list of all ports {}", ports); - int numberToKill = randomIntBetween(1, numberOfLdapServers - 1); - for (int i = 0; i < numberToKill; i++) { - int index = randomIntBetween(0, numberOfLdapServers - 1); - ports.remove(Integer.valueOf(ldapServers[index].getListenPort())); + final int numberToKill = randomIntBetween(1, numberOfLdapServers - 1); + logger.debug("killing [{}] servers", numberToKill); + + // get a subset to kil + final List ldapServersToKill = randomSubsetOf(numberToKill, ldapServers); + final List ldapServersList = Arrays.asList(ldapServers); + for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { + final int index = ldapServersList.indexOf(ldapServerToKill); + assertThat(index, greaterThanOrEqualTo(0)); + final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); + logger.debug("shutting down server index [{}] listening on [{}]", index, port); + assertTrue(ports.remove(port)); ldapServers[index].shutDown(true); + assertThat(ldapServers[index].getListenPort(), is(-1)); } final int numberOfIterations = randomIntBetween(1, 5); for (int iteration = 0; iteration < numberOfIterations; iteration++) { + logger.debug("iteration [{}]", iteration); for (Integer port : ports) { LDAPConnection connection = null; try { + logger.debug("attempting connection with expected port [{}]", port); connection = testSessionFactory.getServerSet().getConnection(); assertThat(connection.getConnectedPort(), is(port)); } finally { @@ -76,6 +92,7 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase { public void testFailover() throws Exception { assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.FAILOVER); // first test that there is no round robin stuff going on @@ -92,32 +109,46 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase { } } - List stoppedServers = new ArrayList<>(); - // now we should kill some servers including the first one - int numberToKill = randomIntBetween(1, numberOfLdapServers - 1); - // always kill the first one, but don't add to the list + logger.debug("shutting down server index [0] listening on [{}]", ldapServers[0].getListenPort()); + // always kill the first one ldapServers[0].shutDown(true); - stoppedServers.add(0); - for (int i = 0; i < numberToKill - 1; i++) { - int index = randomIntBetween(1, numberOfLdapServers - 1); - ldapServers[index].shutDown(true); - stoppedServers.add(index); + assertThat(ldapServers[0].getListenPort(), is(-1)); + + // now randomly shutdown some others + if (ldapServers.length > 2) { + // kill at least one other server, but we need at least one good one. Hence the upper bound is number - 2 since we need at least + // one server to use! + final int numberToKill = randomIntBetween(1, numberOfLdapServers - 2); + InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); + // get a subset to kil + final List ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); + final List ldapServersList = Arrays.asList(ldapServers); + for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { + final int index = ldapServersList.indexOf(ldapServerToKill); + assertThat(index, greaterThanOrEqualTo(1)); + final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); + logger.debug("shutting down server index [{}] listening on [{}]", index, port); + ldapServers[index].shutDown(true); + assertThat(ldapServers[index].getListenPort(), is(-1)); + } } int firstNonStoppedPort = -1; // now we find the first that isn't stopped for (int i = 0; i < numberOfLdapServers; i++) { - if (stoppedServers.contains(i) == false) { + if (ldapServers[i].getListenPort() != -1) { firstNonStoppedPort = ldapServers[i].getListenPort(); break; } } + logger.debug("first non stopped port [{}]", firstNonStoppedPort); assertThat(firstNonStoppedPort, not(-1)); final int numberOfIterations = randomIntBetween(1, 5); for (int iteration = 0; iteration < numberOfIterations; iteration++) { LDAPConnection connection = null; try { + logger.debug("attempting connection with expected port [{}] iteration [{}]", firstNonStoppedPort, iteration); connection = testSessionFactory.getServerSet().getConnection(); assertThat(connection.getConnectedPort(), is(firstNonStoppedPort)); } finally {