diff --git a/lucene/core/src/java/org/apache/lucene/search/DisiWrapper.java b/lucene/core/src/java/org/apache/lucene/search/DisiWrapper.java
index 28ba989be62..f2543409f1d 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DisiWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DisiWrapper.java
@@ -60,18 +60,6 @@ public class DisiWrapper {
     }
   }
 
-  // For TermInSetQuery
-  public DisiWrapper(DocIdSetIterator iterator) {
-    this.scorer = null;
-    this.spans = null;
-    this.iterator = iterator;
-    this.cost = iterator.cost();
-    this.doc = -1;
-    this.twoPhaseView = null;
-    this.approximation = iterator;
-    this.matchCost = 0f;
-  }
-
   public DisiWrapper(Spans spans) {
     this.scorer = null;
     this.spans = spans;
diff --git a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
index 5a6676fca90..9b64d379174 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
@@ -17,9 +17,12 @@
 package org.apache.lucene.search;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.SortedSet;
 
@@ -30,6 +33,8 @@ import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.PrefixCodedTerms;
 import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermContext;
+import org.apache.lucene.index.TermState;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.BooleanClause.Occur;
@@ -38,7 +43,6 @@ import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
 import org.apache.lucene.util.DocIdSetBuilder;
-import org.apache.lucene.util.PriorityQueue;
 import org.apache.lucene.util.RamUsageEstimator;
 
 /**
@@ -167,6 +171,39 @@ public class TermInSetQuery extends Query implements Accountable {
     return Collections.emptyList();
   }
 
+  private static class TermAndState {
+    final String field;
+    final TermsEnum termsEnum;
+    final BytesRef term;
+    final TermState state;
+    final int docFreq;
+    final long totalTermFreq;
+
+    TermAndState(String field, TermsEnum termsEnum) throws IOException {
+      this.field = field;
+      this.termsEnum = termsEnum;
+      this.term = BytesRef.deepCopyOf(termsEnum.term());
+      this.state = termsEnum.termState();
+      this.docFreq = termsEnum.docFreq();
+      this.totalTermFreq = termsEnum.totalTermFreq();
+    }
+  }
+
+  private static class WeightOrDocIdSet {
+    final Weight weight;
+    final DocIdSet set;
+
+    WeightOrDocIdSet(Weight weight) {
+      this.weight = Objects.requireNonNull(weight);
+      this.set = null;
+    }
+
+    WeightOrDocIdSet(DocIdSet bitset) {
+      this.set = bitset;
+      this.weight = null;
+    }
+  }
+
   @Override
   public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
     return new ConstantScoreWeight(this, boost) {
@@ -179,8 +216,11 @@ public class TermInSetQuery extends Query implements Accountable {
         // order to protect highlighters
       }
 
-      @Override
-      public Scorer scorer(LeafReaderContext context) throws IOException {
+      /**
+       * On the given leaf context, try to either rewrite to a disjunction if
+       * there are few matching terms, or build a bitset containing matching docs.
+       */
+      private WeightOrDocIdSet rewrite(LeafReaderContext context) throws IOException {
        final LeafReader reader = context.reader();
 
         Terms terms = reader.terms(field);
@@ -191,49 +231,90 @@
         PostingsEnum docs = null;
         TermIterator iterator = termData.iterator();
 
-        // Here we partition postings based on cost: longer ones will be consumed
-        // lazily while shorter ones are consumed eagerly into a bitset. Compared to
-        // putting everything into a bitset, this should help skip over unnecessary doc
-        // ids in the longer postings lists. This should be especially useful if
-        // document frequencies have a zipfian distribution.
-        final PriorityQueue<PostingsEnum> longestPostingsLists = new PriorityQueue<PostingsEnum>(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
-          @Override
-          protected boolean lessThan(PostingsEnum a, PostingsEnum b) {
-            return a.cost() < b.cost();
-          }
-        };
-        DocIdSetBuilder shortestPostingsLists = null;
+        // We will first try to collect up to 'threshold' terms into 'matchingTerms'
+        // if there are too many terms, we will fall back to building the 'builder'
+        final int threshold = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount());
+        assert termData.size() > threshold : "Query should have been rewritten";
+        List<TermAndState> matchingTerms = new ArrayList<>(threshold);
+        DocIdSetBuilder builder = null;
 
         for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
           assert field.equals(iterator.field());
           if (termsEnum.seekExact(term)) {
-            docs = termsEnum.postings(docs, PostingsEnum.NONE);
-            docs = longestPostingsLists.insertWithOverflow(docs);
-            if (docs != null) { // the pq is full
-              if (shortestPostingsLists == null) {
-                shortestPostingsLists = new DocIdSetBuilder(reader.maxDoc());
+            if (matchingTerms == null) {
+              docs = termsEnum.postings(docs, PostingsEnum.NONE);
+              builder.add(docs);
+            } else if (matchingTerms.size() < threshold) {
+              matchingTerms.add(new TermAndState(field, termsEnum));
+            } else {
+              assert matchingTerms.size() == threshold;
+              builder = new DocIdSetBuilder(reader.maxDoc(), terms);
+              docs = termsEnum.postings(docs, PostingsEnum.NONE);
+              builder.add(docs);
+              for (TermAndState t : matchingTerms) {
+                t.termsEnum.seekExact(t.term, t.state);
+                docs = t.termsEnum.postings(docs, PostingsEnum.NONE);
+                builder.add(docs);
               }
-              shortestPostingsLists.add(docs);
+              matchingTerms = null;
             }
           }
         }
+        if (matchingTerms != null) {
+          assert builder == null;
+          BooleanQuery.Builder bq = new BooleanQuery.Builder();
+          for (TermAndState t : matchingTerms) {
+            final TermContext termContext = new TermContext(searcher.getTopReaderContext());
+            termContext.register(t.state, context.ord, t.docFreq, t.totalTermFreq);
+            bq.add(new TermQuery(new Term(t.field, t.term), termContext), Occur.SHOULD);
+          }
+          Query q = new ConstantScoreQuery(bq.build());
+          final Weight weight = searcher.rewrite(q).createWeight(searcher, needsScores, score());
+          return new WeightOrDocIdSet(weight);
+        } else {
+          assert builder != null;
+          return new WeightOrDocIdSet(builder.build());
+        }
+      }
 
-        final int numClauses = longestPostingsLists.size() + (shortestPostingsLists == null ? 0 : 1);
-        if (numClauses == 0) {
+      private Scorer scorer(DocIdSet set) throws IOException {
+        if (set == null) {
           return null;
         }
+        final DocIdSetIterator disi = set.iterator();
+        if (disi == null) {
+          return null;
+        }
+        return new ConstantScoreScorer(this, score(), disi);
+      }
 
-        DisiPriorityQueue queue = new DisiPriorityQueue(numClauses);
-        for (PostingsEnum postings : longestPostingsLists) {
-          queue.add(new DisiWrapper(postings));
+      @Override
+      public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
+        final WeightOrDocIdSet weightOrBitSet = rewrite(context);
+        if (weightOrBitSet == null) {
+          return null;
+        } else if (weightOrBitSet.weight != null) {
+          return weightOrBitSet.weight.bulkScorer(context);
+        } else {
+          final Scorer scorer = scorer(weightOrBitSet.set);
+          if (scorer == null) {
+            return null;
+          }
+          return new DefaultBulkScorer(scorer);
         }
-        if (shortestPostingsLists != null) {
-          queue.add(new DisiWrapper(shortestPostingsLists.build().iterator()));
+      }
+
+      @Override
+      public Scorer scorer(LeafReaderContext context) throws IOException {
+        final WeightOrDocIdSet weightOrBitSet = rewrite(context);
+        if (weightOrBitSet == null) {
+          return null;
+        } else if (weightOrBitSet.weight != null) {
+          return weightOrBitSet.weight.scorer(context);
+        } else {
+          return scorer(weightOrBitSet.set);
        }
-        final DocIdSetIterator disi = new DisjunctionDISIApproximation(queue);
-        return new ConstantScoreScorer(this, boost, disi);
       }
     };
   }
-
 }
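
For context when reviewing, here is a rough standalone sketch of how the query is driven from the outside. It is not part of the patch; the scaffolding (RAMDirectory, StandardAnalyzer, the "id" field, the 50-term list, and the TermInSetQueryDemo class name) is made up for illustration and assumes a Lucene version from around this change. The point is that callers build TermInSetQuery exactly as before; the choice between the TermQuery disjunction and the bitset now happens per leaf inside the weight created above.

// Hypothetical demo, not part of this patch: runs a TermInSetQuery end to end so
// that the new per-segment rewrite() inside the anonymous ConstantScoreWeight is
// what actually executes underneath.
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;

public class TermInSetQueryDemo {
  public static void main(String[] args) throws Exception {
    Directory dir = new RAMDirectory();
    try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
      for (int i = 0; i < 1000; i++) {
        Document doc = new Document();
        doc.add(new StringField("id", Integer.toString(i), Field.Store.YES));
        writer.addDocument(doc);
      }
    }

    // Use more terms than BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD so the query does not
    // rewrite to a plain BooleanQuery up front and the new per-leaf logic is reached.
    List<BytesRef> ids = new ArrayList<>();
    for (int i = 0; i < 50; i++) {
      ids.add(new BytesRef(Integer.toString(i)));
    }

    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      // On this toy index every id exists, so segments will typically take the
      // DocIdSetBuilder path; a segment containing only a few of the ids would
      // instead get the cheap TermQuery-disjunction rewrite.
      TopDocs hits = searcher.search(new TermInSetQuery("id", ids), 10);
      System.out.println("total hits: " + hits.totalHits);
    }
  }
}

The new per-leaf logic only matters when the query holds more terms than BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD overall (otherwise the existing top-level rewrite already produces a BooleanQuery); on leaves where few of those terms actually occur, it now builds a small disjunction that can skip, rather than eagerly materializing a bitset.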