diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index ae0a157b46c..eb3d1159a53 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -56,7 +56,7 @@ Improvements Optimizations -* LUCENE-7330: Speed up conjunction queries. (Adrien Grand) +* LUCENE-7330, LUCENE-7339: Speed up conjunction queries. (Adrien Grand) Other diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java index 205a349b73b..43d03b27978 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java @@ -19,11 +19,15 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; import org.apache.lucene.search.spans.Spans; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.CollectionUtil; /** A conjunction of DocIdSetIterators. @@ -47,11 +51,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { addScorer(scorer, allIterators, twoPhaseIterators); } - DocIdSetIterator iterator = new ConjunctionDISI(allIterators); - if (twoPhaseIterators.isEmpty() == false) { - iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators)); - } - return iterator; + return createConjunction(allIterators, twoPhaseIterators); } /** Create a conjunction over the provided DocIdSetIterators. Note that the @@ -68,11 +68,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { addIterator(iterator, allIterators, twoPhaseIterators); } - DocIdSetIterator iterator = new ConjunctionDISI(allIterators); - if (twoPhaseIterators.isEmpty() == false) { - iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators)); - } - return iterator; + return createConjunction(allIterators, twoPhaseIterators); } /** Create a conjunction over the provided {@link Spans}. Note that the @@ -89,11 +85,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { addSpans(spans, allIterators, twoPhaseIterators); } - DocIdSetIterator iterator = new ConjunctionDISI(allIterators); - if (twoPhaseIterators.isEmpty() == false) { - iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators)); - } - return iterator; + return createConjunction(allIterators, twoPhaseIterators); } /** Adds the scorer, possibly splitting up into two phases or collapsing if it is another conjunction */ @@ -127,6 +119,10 @@ public final class ConjunctionDISI extends DocIdSetIterator { allIterators.add(conjunction.lead1); allIterators.add(conjunction.lead2); Collections.addAll(allIterators, conjunction.others); + } else if (disi.getClass() == BitSetConjunctionDISI.class) { + BitSetConjunctionDISI conjunction = (BitSetConjunctionDISI) disi; + allIterators.add(conjunction.lead); + Collections.addAll(allIterators, conjunction.bitSetIterators); } else { allIterators.add(disi); } @@ -141,6 +137,41 @@ public final class ConjunctionDISI extends DocIdSetIterator { } } + private static DocIdSetIterator createConjunction( + List allIterators, + List twoPhaseIterators) { + long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong(); + List bitSetIterators = new ArrayList<>(); + List iterators = new ArrayList<>(); + for (DocIdSetIterator iterator : allIterators) { + if (iterator.cost() > minCost && iterator instanceof BitSetIterator) { + // we put all bitset iterators into bitSetIterators + // except if they have the minimum cost, since we need + // them to lead the iteration in that case + bitSetIterators.add((BitSetIterator) iterator); + } else { + iterators.add(iterator); + } + } + + DocIdSetIterator disi; + if (iterators.size() == 1) { + disi = iterators.get(0); + } else { + disi = new ConjunctionDISI(iterators); + } + + if (bitSetIterators.size() > 0) { + disi = new BitSetConjunctionDISI(disi, bitSetIterators); + } + + if (twoPhaseIterators.isEmpty() == false) { + disi = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(disi, twoPhaseIterators)); + } + + return disi; + } + final DocIdSetIterator lead1, lead2; final DocIdSetIterator[] others; @@ -214,6 +245,69 @@ public final class ConjunctionDISI extends DocIdSetIterator { return lead1.cost(); // overestimate } + /** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */ + private static class BitSetConjunctionDISI extends DocIdSetIterator { + + private final DocIdSetIterator lead; + private final BitSetIterator[] bitSetIterators; + private final BitSet[] bitSets; + private final int minLength; + + BitSetConjunctionDISI(DocIdSetIterator lead, Collection bitSetIterators) { + this.lead = lead; + assert bitSetIterators.size() > 0; + this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]); + // Put the least costly iterators first so that we exit as soon as possible + ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost())); + this.bitSets = new BitSet[this.bitSetIterators.length]; + int minLen = Integer.MAX_VALUE; + for (int i = 0; i < this.bitSetIterators.length; ++i) { + BitSet bitSet = this.bitSetIterators[i].getBitSet(); + this.bitSets[i] = bitSet; + minLen = Math.min(minLen, bitSet.length()); + } + this.minLength = minLen; + } + + @Override + public int docID() { + return lead.docID(); + } + + @Override + public int nextDoc() throws IOException { + return doNext(lead.nextDoc()); + } + + @Override + public int advance(int target) throws IOException { + return doNext(lead.advance(target)); + } + + private int doNext(int doc) throws IOException { + advanceLead: for (;; doc = lead.nextDoc()) { + if (doc >= minLength) { + return NO_MORE_DOCS; + } + for (BitSet bitSet : bitSets) { + if (bitSet.get(doc) == false) { + continue advanceLead; + } + } + for (BitSetIterator iterator : bitSetIterators) { + iterator.setDocId(doc); + } + return doc; + } + } + + @Override + public long cost() { + return lead.cost(); + } + + } + /** * {@link TwoPhaseIterator} implementing a conjunction. */ diff --git a/lucene/core/src/java/org/apache/lucene/search/LRUQueryCache.java b/lucene/core/src/java/org/apache/lucene/search/LRUQueryCache.java index 859864572b1..f55536505e2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LRUQueryCache.java +++ b/lucene/core/src/java/org/apache/lucene/search/LRUQueryCache.java @@ -40,6 +40,8 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.TieredMergePolicy; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.BitDocIdSet; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RoaringDocIdSet; @@ -502,9 +504,39 @@ public class LRUQueryCache implements QueryCache, Accountable { } /** - * Default cache implementation: uses {@link RoaringDocIdSet}. + * Default cache implementation: uses {@link RoaringDocIdSet} for sets that + * have a density < 1% and a {@link BitDocIdSet} over a {@link FixedBitSet} + * otherwise. */ protected DocIdSet cacheImpl(BulkScorer scorer, int maxDoc) throws IOException { + if (scorer.cost() * 100 >= maxDoc) { + // FixedBitSet is faster for dense sets and will enable the random-access + // optimization in ConjunctionDISI + return cacheIntoBitSet(scorer, maxDoc); + } else { + return cacheIntoRoaringDocIdSet(scorer, maxDoc); + } + } + + private static DocIdSet cacheIntoBitSet(BulkScorer scorer, int maxDoc) throws IOException { + final FixedBitSet bitSet = new FixedBitSet(maxDoc); + long cost[] = new long[1]; + scorer.score(new LeafCollector() { + + @Override + public void setScorer(Scorer scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + cost[0]++; + bitSet.set(doc); + } + + }, null); + return new BitDocIdSet(bitSet, cost[0]); + } + + private static DocIdSet cacheIntoRoaringDocIdSet(BulkScorer scorer, int maxDoc) throws IOException { RoaringDocIdSet.Builder builder = new RoaringDocIdSet.Builder(maxDoc); scorer.score(new LeafCollector() { diff --git a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java index 1a51d801aac..9e88c6d7936 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java +++ b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java @@ -62,11 +62,21 @@ public class BitSetIterator extends DocIdSetIterator { this.cost = cost; } + /** Return the wrapped {@link BitSet}. */ + public BitSet getBitSet() { + return bits; + } + @Override public int docID() { return doc; } + /** Set the current doc id that this iterator is on. */ + public void setDocId(int docId) { + this.doc = docId; + } + @Override public int nextDoc() { return advance(doc + 1); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java b/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java index dcb666417c7..9835f35f1c1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java @@ -30,12 +30,18 @@ import org.apache.lucene.util.TestUtil; public class TestConjunctionDISI extends LuceneTestCase { - private static TwoPhaseIterator approximation(final DocIdSetIterator iterator, final FixedBitSet confirmed) { - return new TwoPhaseIterator(iterator) { + private static TwoPhaseIterator approximation(DocIdSetIterator iterator, final FixedBitSet confirmed) { + DocIdSetIterator approximation; + if (random().nextBoolean()) { + approximation = anonymizeIterator(iterator); + } else { + approximation = iterator; + } + return new TwoPhaseIterator(approximation) { @Override public boolean matches() throws IOException { - return confirmed.get(iterator.docID()); + return confirmed.get(approximation.docID()); } @Override @@ -45,6 +51,33 @@ public class TestConjunctionDISI extends LuceneTestCase { }; } + /** Return an anonym class so that ConjunctionDISI cannot optimize it + * like it does eg. for BitSetIterators. */ + private static DocIdSetIterator anonymizeIterator(DocIdSetIterator it) { + return new DocIdSetIterator() { + + @Override + public int nextDoc() throws IOException { + return it.nextDoc(); + } + + @Override + public int docID() { + return it.docID(); + } + + @Override + public long cost() { + return it.docID(); + } + + @Override + public int advance(int target) throws IOException { + return it.advance(target); + } + }; + } + private static Scorer scorer(TwoPhaseIterator twoPhaseIterator) { return scorer(TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator), twoPhaseIterator); } @@ -168,16 +201,24 @@ public class TestConjunctionDISI extends LuceneTestCase { final Scorer[] iterators = new Scorer[numIterators]; for (int i = 0; i < iterators.length; ++i) { final FixedBitSet set = randomSet(maxDoc); - if (random().nextBoolean()) { - // simple iterator - sets[i] = set; - iterators[i] = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); - } else { - // scorer with approximation - final FixedBitSet confirmed = clearRandomBits(set); - sets[i] = confirmed; - final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); - iterators[i] = scorer(approximation); + switch (random().nextInt(3)) { + case 0: + // simple iterator + sets[i] = set; + iterators[i] = new ConstantScoreScorer(null, 0f, anonymizeIterator(new BitDocIdSet(set).iterator())); + break; + case 1: + // bitSet iterator + sets[i] = set; + iterators[i] = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); + break; + default: + // scorer with approximation + final FixedBitSet confirmed = clearRandomBits(set); + sets[i] = confirmed; + final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); + iterators[i] = scorer(approximation); + break; } } @@ -232,19 +273,26 @@ public class TestConjunctionDISI extends LuceneTestCase { for (int i = 0; i < numIterators; ++i) { final FixedBitSet set = randomSet(maxDoc); final Scorer newIterator; - if (random().nextBoolean()) { - // simple iterator - sets[i] = set; - newIterator = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); - } else { - // scorer with approximation - final FixedBitSet confirmed = clearRandomBits(set); - sets[i] = confirmed; - final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); - newIterator = scorer(approximation); - hasApproximation = true; + switch (random().nextInt(3)) { + case 0: + // simple iterator + sets[i] = set; + newIterator = new ConstantScoreScorer(null, 0f, anonymizeIterator(new BitDocIdSet(set).iterator())); + break; + case 1: + // bitSet iterator + sets[i] = set; + newIterator = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); + break; + default: + // scorer with approximation + final FixedBitSet confirmed = clearRandomBits(set); + sets[i] = confirmed; + final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); + newIterator = scorer(approximation); + hasApproximation = true; + break; } - if (conjunction == null) { conjunction = newIterator; } else { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestScorerPerf.java b/lucene/core/src/test/org/apache/lucene/search/TestScorerPerf.java index 741ebe37207..a7c1ba8b0ae 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestScorerPerf.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestScorerPerf.java @@ -53,6 +53,7 @@ public class TestScorerPerf extends LuceneTestCase { iw.close(); r = DirectoryReader.open(d); s = newSearcher(r); + s.setQueryCache(null); } public void createRandomTerms(int nDocs, int nTerms, double power, Directory dir) throws Exception {