LUCENE-7339: Specialize conjunction with bit sets.

This commit is contained in:
Adrien Grand 2016-06-22 08:02:17 +02:00
parent a3fc7efbcc
commit 60dcef6c61
6 changed files with 227 additions and 42 deletions

View File

@ -56,7 +56,7 @@ Improvements
Optimizations Optimizations
* LUCENE-7330: Speed up conjunction queries. (Adrien Grand) * LUCENE-7330, LUCENE-7339: Speed up conjunction queries. (Adrien Grand)
Other Other

View File

@ -19,11 +19,15 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import org.apache.lucene.search.spans.Spans; import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.CollectionUtil;
/** A conjunction of DocIdSetIterators. /** A conjunction of DocIdSetIterators.
@ -47,11 +51,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
addScorer(scorer, allIterators, twoPhaseIterators); addScorer(scorer, allIterators, twoPhaseIterators);
} }
DocIdSetIterator iterator = new ConjunctionDISI(allIterators); return createConjunction(allIterators, twoPhaseIterators);
if (twoPhaseIterators.isEmpty() == false) {
iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators));
}
return iterator;
} }
/** Create a conjunction over the provided DocIdSetIterators. Note that the /** Create a conjunction over the provided DocIdSetIterators. Note that the
@ -68,11 +68,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
addIterator(iterator, allIterators, twoPhaseIterators); addIterator(iterator, allIterators, twoPhaseIterators);
} }
DocIdSetIterator iterator = new ConjunctionDISI(allIterators); return createConjunction(allIterators, twoPhaseIterators);
if (twoPhaseIterators.isEmpty() == false) {
iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators));
}
return iterator;
} }
/** Create a conjunction over the provided {@link Spans}. Note that the /** Create a conjunction over the provided {@link Spans}. Note that the
@ -89,11 +85,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
addSpans(spans, allIterators, twoPhaseIterators); addSpans(spans, allIterators, twoPhaseIterators);
} }
DocIdSetIterator iterator = new ConjunctionDISI(allIterators); return createConjunction(allIterators, twoPhaseIterators);
if (twoPhaseIterators.isEmpty() == false) {
iterator = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(iterator, twoPhaseIterators));
}
return iterator;
} }
/** Adds the scorer, possibly splitting up into two phases or collapsing if it is another conjunction */ /** Adds the scorer, possibly splitting up into two phases or collapsing if it is another conjunction */
@ -127,6 +119,10 @@ public final class ConjunctionDISI extends DocIdSetIterator {
allIterators.add(conjunction.lead1); allIterators.add(conjunction.lead1);
allIterators.add(conjunction.lead2); allIterators.add(conjunction.lead2);
Collections.addAll(allIterators, conjunction.others); Collections.addAll(allIterators, conjunction.others);
} else if (disi.getClass() == BitSetConjunctionDISI.class) {
BitSetConjunctionDISI conjunction = (BitSetConjunctionDISI) disi;
allIterators.add(conjunction.lead);
Collections.addAll(allIterators, conjunction.bitSetIterators);
} else { } else {
allIterators.add(disi); allIterators.add(disi);
} }
@ -141,6 +137,41 @@ public final class ConjunctionDISI extends DocIdSetIterator {
} }
} }
private static DocIdSetIterator createConjunction(
List<DocIdSetIterator> allIterators,
List<TwoPhaseIterator> twoPhaseIterators) {
long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();
List<BitSetIterator> bitSetIterators = new ArrayList<>();
List<DocIdSetIterator> iterators = new ArrayList<>();
for (DocIdSetIterator iterator : allIterators) {
if (iterator.cost() > minCost && iterator instanceof BitSetIterator) {
// we put all bitset iterators into bitSetIterators
// except if they have the minimum cost, since we need
// them to lead the iteration in that case
bitSetIterators.add((BitSetIterator) iterator);
} else {
iterators.add(iterator);
}
}
DocIdSetIterator disi;
if (iterators.size() == 1) {
disi = iterators.get(0);
} else {
disi = new ConjunctionDISI(iterators);
}
if (bitSetIterators.size() > 0) {
disi = new BitSetConjunctionDISI(disi, bitSetIterators);
}
if (twoPhaseIterators.isEmpty() == false) {
disi = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(disi, twoPhaseIterators));
}
return disi;
}
final DocIdSetIterator lead1, lead2; final DocIdSetIterator lead1, lead2;
final DocIdSetIterator[] others; final DocIdSetIterator[] others;
@ -214,6 +245,69 @@ public final class ConjunctionDISI extends DocIdSetIterator {
return lead1.cost(); // overestimate return lead1.cost(); // overestimate
} }
/** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */
private static class BitSetConjunctionDISI extends DocIdSetIterator {
private final DocIdSetIterator lead;
private final BitSetIterator[] bitSetIterators;
private final BitSet[] bitSets;
private final int minLength;
BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) {
this.lead = lead;
assert bitSetIterators.size() > 0;
this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]);
// Put the least costly iterators first so that we exit as soon as possible
ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost()));
this.bitSets = new BitSet[this.bitSetIterators.length];
int minLen = Integer.MAX_VALUE;
for (int i = 0; i < this.bitSetIterators.length; ++i) {
BitSet bitSet = this.bitSetIterators[i].getBitSet();
this.bitSets[i] = bitSet;
minLen = Math.min(minLen, bitSet.length());
}
this.minLength = minLen;
}
@Override
public int docID() {
return lead.docID();
}
@Override
public int nextDoc() throws IOException {
return doNext(lead.nextDoc());
}
@Override
public int advance(int target) throws IOException {
return doNext(lead.advance(target));
}
private int doNext(int doc) throws IOException {
advanceLead: for (;; doc = lead.nextDoc()) {
if (doc >= minLength) {
return NO_MORE_DOCS;
}
for (BitSet bitSet : bitSets) {
if (bitSet.get(doc) == false) {
continue advanceLead;
}
}
for (BitSetIterator iterator : bitSetIterators) {
iterator.setDocId(doc);
}
return doc;
}
}
@Override
public long cost() {
return lead.cost();
}
}
/** /**
* {@link TwoPhaseIterator} implementing a conjunction. * {@link TwoPhaseIterator} implementing a conjunction.
*/ */

View File

@ -40,6 +40,8 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.index.TieredMergePolicy; import org.apache.lucene.index.TieredMergePolicy;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables; import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.BitDocIdSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.RoaringDocIdSet; import org.apache.lucene.util.RoaringDocIdSet;
@ -502,9 +504,39 @@ public class LRUQueryCache implements QueryCache, Accountable {
} }
/** /**
* Default cache implementation: uses {@link RoaringDocIdSet}. * Default cache implementation: uses {@link RoaringDocIdSet} for sets that
* have a density &lt; 1% and a {@link BitDocIdSet} over a {@link FixedBitSet}
* otherwise.
*/ */
protected DocIdSet cacheImpl(BulkScorer scorer, int maxDoc) throws IOException { protected DocIdSet cacheImpl(BulkScorer scorer, int maxDoc) throws IOException {
if (scorer.cost() * 100 >= maxDoc) {
// FixedBitSet is faster for dense sets and will enable the random-access
// optimization in ConjunctionDISI
return cacheIntoBitSet(scorer, maxDoc);
} else {
return cacheIntoRoaringDocIdSet(scorer, maxDoc);
}
}
private static DocIdSet cacheIntoBitSet(BulkScorer scorer, int maxDoc) throws IOException {
final FixedBitSet bitSet = new FixedBitSet(maxDoc);
long cost[] = new long[1];
scorer.score(new LeafCollector() {
@Override
public void setScorer(Scorer scorer) throws IOException {}
@Override
public void collect(int doc) throws IOException {
cost[0]++;
bitSet.set(doc);
}
}, null);
return new BitDocIdSet(bitSet, cost[0]);
}
private static DocIdSet cacheIntoRoaringDocIdSet(BulkScorer scorer, int maxDoc) throws IOException {
RoaringDocIdSet.Builder builder = new RoaringDocIdSet.Builder(maxDoc); RoaringDocIdSet.Builder builder = new RoaringDocIdSet.Builder(maxDoc);
scorer.score(new LeafCollector() { scorer.score(new LeafCollector() {

View File

@ -62,11 +62,21 @@ public class BitSetIterator extends DocIdSetIterator {
this.cost = cost; this.cost = cost;
} }
/** Return the wrapped {@link BitSet}. */
public BitSet getBitSet() {
return bits;
}
@Override @Override
public int docID() { public int docID() {
return doc; return doc;
} }
/** Set the current doc id that this iterator is on. */
public void setDocId(int docId) {
this.doc = docId;
}
@Override @Override
public int nextDoc() { public int nextDoc() {
return advance(doc + 1); return advance(doc + 1);

View File

@ -30,12 +30,18 @@ import org.apache.lucene.util.TestUtil;
public class TestConjunctionDISI extends LuceneTestCase { public class TestConjunctionDISI extends LuceneTestCase {
private static TwoPhaseIterator approximation(final DocIdSetIterator iterator, final FixedBitSet confirmed) { private static TwoPhaseIterator approximation(DocIdSetIterator iterator, final FixedBitSet confirmed) {
return new TwoPhaseIterator(iterator) { DocIdSetIterator approximation;
if (random().nextBoolean()) {
approximation = anonymizeIterator(iterator);
} else {
approximation = iterator;
}
return new TwoPhaseIterator(approximation) {
@Override @Override
public boolean matches() throws IOException { public boolean matches() throws IOException {
return confirmed.get(iterator.docID()); return confirmed.get(approximation.docID());
} }
@Override @Override
@ -45,6 +51,33 @@ public class TestConjunctionDISI extends LuceneTestCase {
}; };
} }
/** Return an anonym class so that ConjunctionDISI cannot optimize it
* like it does eg. for BitSetIterators. */
private static DocIdSetIterator anonymizeIterator(DocIdSetIterator it) {
return new DocIdSetIterator() {
@Override
public int nextDoc() throws IOException {
return it.nextDoc();
}
@Override
public int docID() {
return it.docID();
}
@Override
public long cost() {
return it.docID();
}
@Override
public int advance(int target) throws IOException {
return it.advance(target);
}
};
}
private static Scorer scorer(TwoPhaseIterator twoPhaseIterator) { private static Scorer scorer(TwoPhaseIterator twoPhaseIterator) {
return scorer(TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator), twoPhaseIterator); return scorer(TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator), twoPhaseIterator);
} }
@ -168,16 +201,24 @@ public class TestConjunctionDISI extends LuceneTestCase {
final Scorer[] iterators = new Scorer[numIterators]; final Scorer[] iterators = new Scorer[numIterators];
for (int i = 0; i < iterators.length; ++i) { for (int i = 0; i < iterators.length; ++i) {
final FixedBitSet set = randomSet(maxDoc); final FixedBitSet set = randomSet(maxDoc);
if (random().nextBoolean()) { switch (random().nextInt(3)) {
case 0:
// simple iterator // simple iterator
sets[i] = set; sets[i] = set;
iterators[i] = new ConstantScoreScorer(null, 0f, anonymizeIterator(new BitDocIdSet(set).iterator()));
break;
case 1:
// bitSet iterator
sets[i] = set;
iterators[i] = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); iterators[i] = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator());
} else { break;
default:
// scorer with approximation // scorer with approximation
final FixedBitSet confirmed = clearRandomBits(set); final FixedBitSet confirmed = clearRandomBits(set);
sets[i] = confirmed; sets[i] = confirmed;
final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed);
iterators[i] = scorer(approximation); iterators[i] = scorer(approximation);
break;
} }
} }
@ -232,19 +273,26 @@ public class TestConjunctionDISI extends LuceneTestCase {
for (int i = 0; i < numIterators; ++i) { for (int i = 0; i < numIterators; ++i) {
final FixedBitSet set = randomSet(maxDoc); final FixedBitSet set = randomSet(maxDoc);
final Scorer newIterator; final Scorer newIterator;
if (random().nextBoolean()) { switch (random().nextInt(3)) {
case 0:
// simple iterator // simple iterator
sets[i] = set; sets[i] = set;
newIterator = new ConstantScoreScorer(null, 0f, anonymizeIterator(new BitDocIdSet(set).iterator()));
break;
case 1:
// bitSet iterator
sets[i] = set;
newIterator = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator()); newIterator = new ConstantScoreScorer(null, 0f, new BitDocIdSet(set).iterator());
} else { break;
default:
// scorer with approximation // scorer with approximation
final FixedBitSet confirmed = clearRandomBits(set); final FixedBitSet confirmed = clearRandomBits(set);
sets[i] = confirmed; sets[i] = confirmed;
final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed);
newIterator = scorer(approximation); newIterator = scorer(approximation);
hasApproximation = true; hasApproximation = true;
break;
} }
if (conjunction == null) { if (conjunction == null) {
conjunction = newIterator; conjunction = newIterator;
} else { } else {

View File

@ -53,6 +53,7 @@ public class TestScorerPerf extends LuceneTestCase {
iw.close(); iw.close();
r = DirectoryReader.open(d); r = DirectoryReader.open(d);
s = newSearcher(r); s = newSearcher(r);
s.setQueryCache(null);
} }
public void createRandomTerms(int nDocs, int nTerms, double power, Directory dir) throws Exception { public void createRandomTerms(int nDocs, int nTerms, double power, Directory dir) throws Exception {