Fix failures in SessionFactoryLoadBalancingTests (#39154)

This change aims to fix failures in the session factory load balancing
tests that mock failure scenarios. For these tests, we randomly shut
down ldap servers and bind a client socket to the port they were
listening on. Unfortunately, we would occasionally encounter failures
in these tests where a socket was already in use and/or the port
we expected to connect to was wrong and in fact was to one of the ldap
instances that should have been shut down.

The failures are caused by the behavior of certain operating systems
when it comes to binding ports and wildcard addresses. It is possible
for a separate application to be bound to a wildcard address and still
allow our code to bind to that port on a specific address. So when we
close the server socket and open the client socket, we are still able
to establish a connection since the other application is already
listening on that port on a wildcard address. Another variant is that
the os will allow a wildcard bind of a server socket when there is
already an application listening on that port for a specific address.

In order to do our best to prevent failures in these scenarios, this
change does the following:

1. Binds a client socket to all addresses in an awaitBusy
2. Adds assumption that we could bind all valid addresses
3. In the case that we still establish a connection to an address that
   we should not be able to, try to bind and expect a failure of not
   being connected

Closes #32190
This commit is contained in:
Jay Modi 2019-02-20 11:37:26 -07:00 committed by jaymode
parent 3d93011e32
commit af451459a5
No known key found for this signature in database
GPG Key ID: D859847567B3493D
3 changed files with 226 additions and 83 deletions

View File

@ -9,6 +9,7 @@ import com.unboundid.ldap.listener.InMemoryDirectoryServer;
import com.unboundid.ldap.sdk.LDAPException;
import com.unboundid.ldap.sdk.LDAPURL;
import com.unboundid.ldap.sdk.SimpleBindRequest;
import org.elasticsearch.common.network.NetworkAddress;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -29,6 +30,7 @@ import org.elasticsearch.xpack.security.authc.ldap.support.LdapTestCase;
import org.junit.After;
import org.junit.Before;
import java.net.InetAddress;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
@ -73,7 +75,12 @@ public class LdapSessionFactoryTests extends LdapTestCase {
public void testBindWithReadTimeout() throws Exception {
InMemoryDirectoryServer ldapServer = randomFrom(ldapServers);
String protocol = randomFrom("ldap", "ldaps");
String ldapUrl = new LDAPURL(protocol, "localhost", ldapServer.getListenPort(protocol), null, null, null, null).toString();
InetAddress listenAddress = ldapServer.getListenAddress(protocol);
if (listenAddress == null) {
listenAddress = InetAddress.getLoopbackAddress();
}
String ldapUrl = new LDAPURL(protocol, NetworkAddress.format(listenAddress), ldapServer.getListenPort(protocol),
null, null, null, null).toString();
String groupSearchBase = "o=sevenSeas";
String userTemplates = "cn={0},ou=people,o=sevenSeas";
@ -233,7 +240,12 @@ public class LdapSessionFactoryTests extends LdapTestCase {
*/
public void testSslTrustIsReloaded() throws Exception {
InMemoryDirectoryServer ldapServer = randomFrom(ldapServers);
String ldapUrl = new LDAPURL("ldaps", "localhost", ldapServer.getListenPort("ldaps"), null, null, null, null).toString();
InetAddress listenAddress = ldapServer.getListenAddress("ldaps");
if (listenAddress == null) {
listenAddress = InetAddress.getLoopbackAddress();
}
String ldapUrl = new LDAPURL("ldaps", NetworkAddress.format(listenAddress), ldapServer.getListenPort("ldaps"),
null, null, null, null).toString();
String groupSearchBase = "o=sevenSeas";
String userTemplates = "cn={0},ou=people,o=sevenSeas";

View File

@ -18,6 +18,7 @@ import com.unboundid.ldap.sdk.SimpleBindRequest;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.network.NetworkAddress;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
@ -25,6 +26,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.authc.RealmConfig;
import org.elasticsearch.xpack.core.security.authc.ldap.LdapSessionFactorySettings;
import org.elasticsearch.xpack.core.security.authc.ldap.SearchGroupsResolverSettings;
@ -46,6 +48,7 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509ExtendedKeyManager;
import java.net.InetAddress;
import java.security.AccessController;
import java.security.KeyStore;
import java.security.PrivilegedAction;
@ -76,7 +79,7 @@ public abstract class LdapTestCase extends ESTestCase {
for (int i = 0; i < numberOfLdapServers; i++) {
InMemoryDirectoryServerConfig serverConfig = new InMemoryDirectoryServerConfig("o=sevenSeas");
List<InMemoryListenerConfig> listeners = new ArrayList<>(2);
listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap"));
listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap", null, 0, null));
if (openLdapsPort()) {
final char[] ldapPassword = "ldap-password".toCharArray();
final KeyStore ks = CertParsingUtils.getKeyStoreFromPEM(
@ -85,7 +88,7 @@ public abstract class LdapTestCase extends ESTestCase {
ldapPassword
);
X509ExtendedKeyManager keyManager = CertParsingUtils.keyManager(ks, ldapPassword, KeyManagerFactory.getDefaultAlgorithm());
final SSLContext context = SSLContext.getInstance("TLSv1.2");
final SSLContext context = SSLContext.getInstance(XPackSettings.DEFAULT_SUPPORTED_PROTOCOLS.get(0));
context.init(new KeyManager[] { keyManager }, null, null);
SSLServerSocketFactory serverSocketFactory = context.getServerSocketFactory();
SSLSocketFactory clientSocketFactory = context.getSocketFactory();
@ -111,7 +114,7 @@ public abstract class LdapTestCase extends ESTestCase {
}
@After
public void stopLdap() throws Exception {
public void stopLdap() {
for (int i = 0; i < numberOfLdapServers; i++) {
ldapServers[i].shutDown(true);
}
@ -120,7 +123,11 @@ public abstract class LdapTestCase extends ESTestCase {
protected String[] ldapUrls() throws LDAPException {
List<String> urls = new ArrayList<>(numberOfLdapServers);
for (int i = 0; i < numberOfLdapServers; i++) {
LDAPURL url = new LDAPURL("ldap", "localhost", ldapServers[i].getListenPort(), null, null, null, null);
InetAddress listenAddress = ldapServers[i].getListenAddress();
if (listenAddress == null) {
listenAddress = InetAddress.getLoopbackAddress();
}
LDAPURL url = new LDAPURL("ldap", NetworkAddress.format(listenAddress), ldapServers[i].getListenPort(), null, null, null, null);
urls.add(url.toString());
}
return urls.toArray(Strings.EMPTY_ARRAY);

View File

@ -7,11 +7,16 @@ package org.elasticsearch.xpack.security.authc.ldap.support;
import com.unboundid.ldap.listener.InMemoryDirectoryServer;
import com.unboundid.ldap.sdk.LDAPConnection;
import com.unboundid.ldap.sdk.LDAPException;
import com.unboundid.ldap.sdk.SimpleBindRequest;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.network.InetAddressHelper;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.elasticsearch.mocksocket.MockSocket;
@ -28,12 +33,17 @@ import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NoRouteToHostException;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
@ -52,7 +62,7 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
}
@After
public void shutdown() throws InterruptedException {
public void shutdown() {
terminate(threadPool);
}
@ -62,29 +72,22 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
final int numberOfIterations = randomIntBetween(1, 5);
for (int iteration = 0; iteration < numberOfIterations; iteration++) {
for (int i = 0; i < numberOfLdapServers; i++) {
LDAPConnection connection = null;
try {
connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection);
try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) {
assertThat(connection.getConnectedPort(), is(ldapServers[i].getListenPort()));
} finally {
if (connection != null) {
connection.close();
}
}
}
}
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/32190")
public void testRoundRobinWithFailures() throws Exception {
assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1);
assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1);
logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls());
TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.ROUND_ROBIN);
// create a list of ports
List<Integer> ports = new ArrayList<>(numberOfLdapServers);
for (int i = 0; i < ldapServers.length; i++) {
ports.add(ldapServers[i].getListenPort());
for (InMemoryDirectoryServer ldapServer : ldapServers) {
ports.add(ldapServer.getListenPort());
}
logger.debug("list of all ports {}", ports);
@ -94,18 +97,18 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
// get a subset to kill
final List<InMemoryDirectoryServer> ldapServersToKill = randomSubsetOf(numberToKill, ldapServers);
final List<InMemoryDirectoryServer> ldapServersList = Arrays.asList(ldapServers);
final InetAddress local = InetAddress.getByName("localhost");
final MockServerSocket mockServerSocket = new MockServerSocket(0, 0, local);
final MockServerSocket mockServerSocket = new MockServerSocket(0, 0);
final List<Thread> listenThreads = new ArrayList<>();
final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size());
final CountDownLatch closeLatch = new CountDownLatch(1);
try {
final AtomicBoolean success = new AtomicBoolean(true);
for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) {
final int index = ldapServersList.indexOf(ldapServerToKill);
assertThat(index, greaterThanOrEqualTo(0));
final Integer port = Integer.valueOf(ldapServers[index].getListenPort());
final int port = ldapServers[index].getListenPort();
logger.debug("shutting down server index [{}] listening on [{}]", index, port);
assertTrue(ports.remove(port));
assertTrue(ports.remove(Integer.valueOf(port)));
ldapServers[index].shutDown(true);
// when running multiple test jvms, there is a chance that something else could
@ -114,17 +117,9 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
// a mock server socket.
// NOTE: this is not perfect as there is a small amount of time between the shutdown
// of the ldap server and the opening of the socket
logger.debug("opening mock server socket listening on [{}]", port);
Runnable runnable = () -> {
try (Socket socket = openMockSocket(local, mockServerSocket.getLocalPort(), local, port)) {
logger.debug("opened socket [{}]", socket);
latch.countDown();
closeLatch.await();
logger.debug("closing socket [{}]", socket);
} catch (IOException | InterruptedException e) {
logger.debug("caught exception", e);
}
};
logger.debug("opening mock client sockets bound to [{}]", port);
Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port,
latch, closeLatch, success);
Thread thread = new Thread(runnable);
thread.start();
listenThreads.add(thread);
@ -133,14 +128,37 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
}
latch.await();
assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " +
"allow binding to an address and port combination even if an application is bound to the port on a wildcard address",
success.get());
final int numberOfIterations = randomIntBetween(1, 5);
logger.debug("list of all open ports {}", ports);
// go one iteration through and attempt a bind
for (int iteration = 0; iteration < numberOfIterations; iteration++) {
logger.debug("iteration [{}]", iteration);
for (Integer port : ports) {
logger.debug("attempting connection with expected port [{}]", port);
try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) {
LDAPConnection connection = null;
try {
do {
final LDAPConnection finalConnection =
LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection);
connection = finalConnection;
logger.debug("established connection with port [{}] expected port [{}]",
finalConnection.getConnectedPort(), port);
if (finalConnection.getConnectedPort() != port) {
LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest()));
assertThat(e.getMessage(), containsString("not connected"));
finalConnection.close();
}
} while (connection.getConnectedPort() != port);
assertThat(connection.getConnectedPort(), is(port));
} finally {
if (connection != null) {
connection.close();
}
}
}
}
@ -160,76 +178,109 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
socket.setReuseAddress(true); // allow binding even if the previous socket is in timed wait state.
socket.setSoLinger(true, 0); // close immediately as we are not writing anything here.
socket.bind(new InetSocketAddress(localAddress, localPort));
SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(localAddress, remotePort)));
SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(remoteAddress, remotePort)));
return socket;
}
public void testFailover() throws Exception {
assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1);
assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1);
logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls());
TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.FAILOVER);
// first test that there is no round robin stuff going on
final int firstPort = ldapServers[0].getListenPort();
for (int i = 0; i < numberOfLdapServers; i++) {
LDAPConnection connection = null;
try {
connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection);
try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) {
assertThat(connection.getConnectedPort(), is(firstPort));
} finally {
if (connection != null) {
connection.close();
}
}
}
logger.debug("shutting down server index [0] listening on [{}]", ldapServers[0].getListenPort());
// always kill the first one
ldapServers[0].shutDown(true);
assertThat(ldapServers[0].getListenPort(), is(-1));
// now randomly shutdown some others
// we need at least one good server. Hence the upper bound is number - 2 since we need at least
// one server to use!
InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length);
final List<InMemoryDirectoryServer> ldapServersToKill;
if (ldapServers.length > 2) {
// kill at least one other server, but we need at least one good one. Hence the upper bound is number - 2 since we need at least
// one server to use!
final int numberToKill = randomIntBetween(1, numberOfLdapServers - 2);
InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length);
// get a subset to kil
final List<InMemoryDirectoryServer> ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer);
final List<InMemoryDirectoryServer> ldapServersList = Arrays.asList(ldapServers);
for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) {
final int index = ldapServersList.indexOf(ldapServerToKill);
assertThat(index, greaterThanOrEqualTo(1));
final Integer port = Integer.valueOf(ldapServers[index].getListenPort());
logger.debug("shutting down server index [{}] listening on [{}]", index, port);
ldapServers[index].shutDown(true);
assertThat(ldapServers[index].getListenPort(), is(-1));
}
ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer);
ldapServersToKill.add(ldapServers[0]); // always kill the first one
} else {
ldapServersToKill = Collections.singletonList(ldapServers[0]);
}
final List<InMemoryDirectoryServer> ldapServersList = Arrays.asList(ldapServers);
final MockServerSocket mockServerSocket = new MockServerSocket(0, 0);
final List<Thread> listenThreads = new ArrayList<>();
final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size());
final CountDownLatch closeLatch = new CountDownLatch(1);
final AtomicBoolean success = new AtomicBoolean(true);
for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) {
final int index = ldapServersList.indexOf(ldapServerToKill);
final int port = ldapServers[index].getListenPort();
logger.debug("shutting down server index [{}] listening on [{}]", index, port);
ldapServers[index].shutDown(true);
// when running multiple test jvms, there is a chance that something else could
// start listening on this port so we try to avoid this by creating a local socket
// that will be bound to the port the ldap server was running on and connecting to
// a mock server socket.
// NOTE: this is not perfect as there is a small amount of time between the shutdown
// of the ldap server and the opening of the socket
logger.debug("opening mock server socket listening on [{}]", port);
Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port,
latch, closeLatch, success);
Thread thread = new Thread(runnable);
thread.start();
listenThreads.add(thread);
assertThat(ldapServers[index].getListenPort(), is(-1));
}
int firstNonStoppedPort = -1;
// now we find the first that isn't stopped
for (int i = 0; i < numberOfLdapServers; i++) {
if (ldapServers[i].getListenPort() != -1) {
firstNonStoppedPort = ldapServers[i].getListenPort();
break;
}
}
logger.debug("first non stopped port [{}]", firstNonStoppedPort);
try {
latch.await();
assertThat(firstNonStoppedPort, not(-1));
final int numberOfIterations = randomIntBetween(1, 5);
for (int iteration = 0; iteration < numberOfIterations; iteration++) {
LDAPConnection connection = null;
try {
logger.debug("attempting connection with expected port [{}] iteration [{}]", firstNonStoppedPort, iteration);
connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection);
assertThat(connection.getConnectedPort(), is(firstNonStoppedPort));
} finally {
if (connection != null) {
connection.close();
assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " +
"allow binding to an address and port combination even if an application is bound to the port on a wildcard address",
success.get());
int firstNonStoppedPort = -1;
// now we find the first that isn't stopped
for (int i = 0; i < numberOfLdapServers; i++) {
if (ldapServers[i].getListenPort() != -1) {
firstNonStoppedPort = ldapServers[i].getListenPort();
break;
}
}
logger.debug("first non stopped port [{}]", firstNonStoppedPort);
assertThat(firstNonStoppedPort, not(-1));
final int numberOfIterations = randomIntBetween(1, 5);
for (int iteration = 0; iteration < numberOfIterations; iteration++) {
logger.debug("attempting connection with expected port [{}] iteration [{}]", firstNonStoppedPort, iteration);
LDAPConnection connection = null;
try {
do {
final LDAPConnection finalConnection =
LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection);
connection = finalConnection;
logger.debug("established connection with port [{}] expected port [{}]",
finalConnection.getConnectedPort(), firstNonStoppedPort);
if (finalConnection.getConnectedPort() != firstNonStoppedPort) {
LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest()));
assertThat(e.getMessage(), containsString("not connected"));
finalConnection.close();
}
} while (connection.getConnectedPort() != firstNonStoppedPort);
assertThat(connection.getConnectedPort(), is(firstNonStoppedPort));
} finally {
if (connection != null) {
connection.close();
}
}
}
} finally {
closeLatch.countDown();
mockServerSocket.close();
for (Thread t : listenThreads) {
t.join();
}
}
}
@ -245,6 +296,79 @@ public class SessionFactoryLoadBalancingTests extends LdapTestCase {
threadPool);
}
private class PortBlockingRunnable implements Runnable {
private final InetAddress serverAddress;
private final int serverPort;
private final int portToBind;
private final CountDownLatch latch;
private final CountDownLatch closeLatch;
private final AtomicBoolean success;
private PortBlockingRunnable(InetAddress serverAddress, int serverPort, int portToBind, CountDownLatch latch,
CountDownLatch closeLatch, AtomicBoolean success) {
this.serverAddress = serverAddress;
this.serverPort = serverPort;
this.portToBind = portToBind;
this.latch = latch;
this.closeLatch = closeLatch;
this.success = success;
}
@Override
public void run() {
final List<Socket> openedSockets = new ArrayList<>();
final List<InetAddress> blacklistedAddress = new ArrayList<>();
try {
final boolean allSocketsOpened = awaitBusy(() -> {
try {
final List<InetAddress> inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses())
.filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress())))
.filter(addr -> blacklistedAddress.contains(addr) == false)
.collect(Collectors.toList());
for (InetAddress localAddress : inetAddressesToBind) {
try {
final Socket socket = openMockSocket(serverAddress, serverPort, localAddress, portToBind);
openedSockets.add(socket);
logger.debug("opened socket [{}]", socket);
} catch (NoRouteToHostException e) {
logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e);
blacklistedAddress.add(localAddress);
}
}
return true;
} catch (IOException e) {
logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", portToBind), e);
return false;
}
});
if (allSocketsOpened) {
latch.countDown();
} else {
success.set(false);
IOUtils.closeWhileHandlingException(openedSockets);
openedSockets.clear();
latch.countDown();
return;
}
} catch (InterruptedException e) {
logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", portToBind), e);
Thread.currentThread().interrupt();
}
try {
closeLatch.await();
} catch (InterruptedException e) {
logger.debug("caught exception while waiting for close latch", e);
Thread.currentThread().interrupt();
} finally {
logger.debug("closing sockets on [{}]", portToBind);
IOUtils.closeWhileHandlingException(openedSockets);
}
}
}
static class TestSessionFactory extends SessionFactory {
protected TestSessionFactory(RealmConfig config, SSLService sslService, ThreadPool threadPool) {