diff --git a/solr/test-framework/src/java/org/apache/solr/cloud/ZkTestServer.java b/solr/test-framework/src/java/org/apache/solr/cloud/ZkTestServer.java index 9a044a24477..c4747c9e9d0 100644 --- a/solr/test-framework/src/java/org/apache/solr/cloud/ZkTestServer.java +++ b/solr/test-framework/src/java/org/apache/solr/cloud/ZkTestServer.java @@ -26,13 +26,27 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; import java.net.UnknownHostException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; import javax.management.JMException; +import com.google.common.collect.Ordering; +import com.google.common.util.concurrent.AtomicLongMap; +import org.apache.zookeeper.KeeperException; +import org.apache.zookeeper.WatchedEvent; +import org.apache.zookeeper.Watcher; +import org.apache.zookeeper.data.Stat; import org.apache.zookeeper.jmx.ManagedUtil; +import org.apache.zookeeper.server.NIOServerCnxn; +import org.apache.zookeeper.server.NIOServerCnxnFactory; +import org.apache.zookeeper.server.ServerCnxn; import org.apache.zookeeper.server.ServerCnxnFactory; import org.apache.zookeeper.server.ServerConfig; import org.apache.zookeeper.server.SessionTracker.Session; @@ -58,11 +72,19 @@ public class ZkTestServer { private int theTickTime = TICK_TIME; + static public enum LimitViolationAction { + IGNORE, + REPORT, + FAIL, + } + class ZKServerMain { private ServerCnxnFactory cnxnFactory; private ZooKeeperServer zooKeeperServer; - + private LimitViolationAction violationReportAction = LimitViolationAction.REPORT; + private WatchLimiter limiter = new WatchLimiter(1, LimitViolationAction.IGNORE); + protected void initializeAndRun(String[] args) throws ConfigException, IOException { try { @@ -81,6 +103,182 @@ public class ZkTestServer { runFromConfig(config); } + private class WatchLimit { + private long limit; + private final String desc; + + private LimitViolationAction action; + private AtomicLongMap counters = AtomicLongMap.create(); + private ConcurrentHashMap maxCounters = new ConcurrentHashMap<>(); + + WatchLimit(long limit, String desc, LimitViolationAction action) { + this.limit = limit; + this.desc = desc; + this.action = action; + } + + public void setAction(LimitViolationAction action) { + this.action = action; + } + + public void setLimit(long limit) { + this.limit = limit; + } + + public void updateForWatch(String key, Watcher watcher) { + if (watcher != null) { + log.debug("Watch added: {}: {}", desc, key); + long count = counters.incrementAndGet(key); + Long lastCount = maxCounters.get(key); + if (lastCount == null || count > lastCount) { + maxCounters.put(key, count); + } + if (count > limit && action != LimitViolationAction.IGNORE) { + String msg = "Number of watches created in parallel for data: " + key + + ", type: " + desc + " exceeds limit (" + count + " > " + limit + ")"; + log.warn("{}", msg); + if (action == LimitViolationAction.FAIL) throw new AssertionError(msg); + } + } + } + + public void updateForFire(WatchedEvent event) { + log.debug("Watch fired: {}: {}", desc, event.getPath()); + counters.decrementAndGet(event.getPath()); + } + + private String reportLimitViolations() { + Object[] maxKeys = maxCounters.keySet().toArray(); + Arrays.sort(maxKeys, new Comparator() { + private final Comparator valComp = Ordering.natural().reverse(); + @Override + public int compare(Object o1, Object o2) { + return valComp.compare(maxCounters.get(o1), maxCounters.get(o2)); + } + }); + + StringBuilder sb = new StringBuilder(); + boolean first = true; + for (Object key : maxKeys) { + long value = maxCounters.get(key); + if (value <= limit) continue; + if (first) { + sb.append("\nMaximum concurrent ").append(desc).append(" watches above limit:\n\n"); + first = false; + } + sb.append("\t").append(maxCounters.get(key)).append('\t').append(key).append('\n'); + } + return sb.toString(); + } + } + + public class WatchLimiter { + WatchLimit statLimit; + WatchLimit dataLimit; + WatchLimit childrenLimit; + + private WatchLimiter (long limit, LimitViolationAction action) { + statLimit = new WatchLimit(limit, "create/delete", action); + dataLimit = new WatchLimit(limit, "data", action); + childrenLimit = new WatchLimit(limit, "children", action); + } + + public void setAction(LimitViolationAction action) { + statLimit.setAction(action); + dataLimit.setAction(action); + childrenLimit.setAction(action); + } + + public void setLimit(long limit) { + statLimit.setLimit(limit); + dataLimit.setLimit(limit); + childrenLimit.setLimit(limit); + } + + public String reportLimitViolations() { + return statLimit.reportLimitViolations() + + dataLimit.reportLimitViolations() + + childrenLimit.reportLimitViolations(); + } + + private void updateForFire(WatchedEvent event) { + switch (event.getType()) { + case None: + break; + case NodeCreated: + case NodeDeleted: + statLimit.updateForFire(event); + break; + case NodeDataChanged: + dataLimit.updateForFire(event); + break; + case NodeChildrenChanged: + childrenLimit.updateForFire(event); + break; + } + } + } + + private class TestServerCnxn extends NIOServerCnxn { + + private final WatchLimiter limiter; + + public TestServerCnxn(ZooKeeperServer zk, SocketChannel sock, SelectionKey sk, + NIOServerCnxnFactory factory, WatchLimiter limiter) throws IOException { + super(zk, sock, sk, factory); + this.limiter = limiter; + } + + @Override + public synchronized void process(WatchedEvent event) { + limiter.updateForFire(event); + super.process(event); + } + } + + private class TestServerCnxnFactory extends NIOServerCnxnFactory { + + private final WatchLimiter limiter; + + public TestServerCnxnFactory(WatchLimiter limiter) throws IOException { + super(); + this.limiter = limiter; + } + + @Override + protected NIOServerCnxn createConnection(SocketChannel sock, SelectionKey sk) throws IOException { + return new TestServerCnxn(zkServer, sock, sk, this, limiter); + } + } + + private class TestZKDatabase extends ZKDatabase { + + private final WatchLimiter limiter; + + public TestZKDatabase(FileTxnSnapLog snapLog, WatchLimiter limiter) { + super(snapLog); + this.limiter = limiter; + } + + @Override + public Stat statNode(String path, ServerCnxn serverCnxn) throws KeeperException.NoNodeException { + limiter.statLimit.updateForWatch(path, serverCnxn); + return super.statNode(path, serverCnxn); + } + + @Override + public byte[] getData(String path, Stat stat, Watcher watcher) throws KeeperException.NoNodeException { + limiter.dataLimit.updateForWatch(path, watcher); + return super.getData(path, stat, watcher); + } + + @Override + public List getChildren(String path, Stat stat, Watcher watcher) throws KeeperException.NoNodeException { + limiter.childrenLimit.updateForWatch(path, watcher); + return super.getChildren(path, stat, watcher); + } + } + /** * Run from a ServerConfig. * @param config ServerConfig to use. @@ -93,15 +291,12 @@ public class ZkTestServer { // so rather than spawning another thread, we will just call // run() in this thread. // create a file logger url from the command line args - zooKeeperServer = new ZooKeeperServer(); - FileTxnSnapLog ftxn = new FileTxnSnapLog(new File( config.getDataLogDir()), new File(config.getDataDir())); - zooKeeperServer.setTxnLogFactory(ftxn); - zooKeeperServer.setTickTime(config.getTickTime()); - zooKeeperServer.setMinSessionTimeout(config.getMinSessionTimeout()); - zooKeeperServer.setMaxSessionTimeout(config.getMaxSessionTimeout()); - cnxnFactory = ServerCnxnFactory.createFactory(); + zooKeeperServer = new ZooKeeperServer(ftxn, config.getTickTime(), + config.getMinSessionTimeout(), config.getMaxSessionTimeout(), + null /* this is not used */, new TestZKDatabase(ftxn, limiter)); + cnxnFactory = new TestServerCnxnFactory(limiter); cnxnFactory.configure(config.getClientPortAddress(), config.getMaxClientCnxns()); cnxnFactory.startup(zooKeeperServer); @@ -109,6 +304,15 @@ public class ZkTestServer { // if (zooKeeperServer.isRunning()) { zkServer.shutdown(); // } + if (violationReportAction != LimitViolationAction.IGNORE) { + String limitViolations = limiter.reportLimitViolations(); + if (!limitViolations.isEmpty()) { + log.warn("Watch limit violations: {}", limitViolations); + if (violationReportAction == LimitViolationAction.FAIL) { + throw new AssertionError("Parallel watch limits violated"); + } + } + } } catch (InterruptedException e) { // warn, but generally this is ok log.warn("Server interrupted", e); @@ -153,6 +357,14 @@ public class ZkTestServer { } return port; } + + public void setViolationReportAction(LimitViolationAction violationReportAction) { + this.violationReportAction = violationReportAction; + } + + public WatchLimiter getLimiter() { + return limiter; + } } public ZkTestServer(String zkDir) { @@ -162,6 +374,16 @@ public class ZkTestServer { public ZkTestServer(String zkDir, int port) { this.zkDir = zkDir; this.clientPort = port; + String reportAction = System.getProperty("tests.zk.violationReportAction"); + if (reportAction != null) { + log.info("Overriding violation report action to: {}", reportAction); + setViolationReportAction(LimitViolationAction.valueOf(reportAction)); + } + String limiterAction = System.getProperty("tests.zk.limiterAction"); + if (limiterAction != null) { + log.info("Overriding limiter action to: {}", limiterAction); + getLimiter().setAction(LimitViolationAction.valueOf(limiterAction)); + } } public String getZkHost() { @@ -182,14 +404,17 @@ public class ZkTestServer { public long getSessionId() { return sessionId; } + @Override public int getTimeout() { return 4000; } + @Override public boolean isClosing() { return false; - }}); + } + }); } public void run() throws InterruptedException { @@ -239,14 +464,14 @@ public class ZkTestServer { try { port = getPort(); } catch(IllegalStateException e) { - + } while (port < 1) { Thread.sleep(100); try { port = getPort(); } catch(IllegalStateException e) { - + } if (cnt == 500) { throw new RuntimeException("Could not get the port for ZooKeeper server"); @@ -310,31 +535,29 @@ public class ZkTestServer { public static String send4LetterWord(String host, int port, String cmd) throws IOException { - log.info("connecting to " + host + " " + port); - Socket sock = new Socket(host, port); - BufferedReader reader = null; - try { - OutputStream outstream = sock.getOutputStream(); - outstream.write(cmd.getBytes(StandardCharsets.US_ASCII)); - outstream.flush(); - // this replicates NC - close the output stream before reading - sock.shutdownOutput(); + log.info("connecting to " + host + " " + port); + BufferedReader reader = null; + try (Socket sock = new Socket(host, port)) { + OutputStream outstream = sock.getOutputStream(); + outstream.write(cmd.getBytes(StandardCharsets.US_ASCII)); + outstream.flush(); + // this replicates NC - close the output stream before reading + sock.shutdownOutput(); - reader = - new BufferedReader( - new InputStreamReader(sock.getInputStream(), "US-ASCII")); - StringBuilder sb = new StringBuilder(); - String line; - while((line = reader.readLine()) != null) { - sb.append(line + "\n"); - } - return sb.toString(); - } finally { - sock.close(); - if (reader != null) { - reader.close(); - } + reader = + new BufferedReader( + new InputStreamReader(sock.getInputStream(), "US-ASCII")); + StringBuilder sb = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + sb.append(line).append("\n"); } + return sb.toString(); + } finally { + if (reader != null) { + reader.close(); + } + } } public static List parseHostPortList(String hplist) { @@ -364,4 +587,12 @@ public class ZkTestServer { public String getZkDir() { return zkDir; } + + public void setViolationReportAction(LimitViolationAction violationReportAction) { + zkServer.setViolationReportAction(violationReportAction); + } + + public ZKServerMain.WatchLimiter getLimiter() { + return zkServer.getLimiter(); + } }