diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java index 048bb0c4823..1fbb6eaffc2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java @@ -20,6 +20,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; @@ -60,6 +61,21 @@ final class MultiTermQueryConstantScoreWrapper extends } } + private static class WeightOrBitSet { + final Weight weight; + final BitDocIdSet bitset; + + WeightOrBitSet(Weight weight) { + this.weight = Objects.requireNonNull(weight); + this.bitset = null; + } + + WeightOrBitSet(BitDocIdSet bitset) { + this.bitset = bitset; + this.weight = null; + } + } + protected final Q query; /** @@ -116,18 +132,20 @@ final class MultiTermQueryConstantScoreWrapper extends return termsEnum.next() == null; } - @Override - public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException { + /** + * On the given leaf context, try to either rewrite to a disjunction if + * there are few terms, or build a bitset containing matching docs. + */ + private WeightOrBitSet rewrite(LeafReaderContext context, Bits acceptDocs) throws IOException { final Terms terms = context.reader().terms(query.field); if (terms == null) { // field does not exist - return null; + return new WeightOrBitSet((BitDocIdSet) null); } final TermsEnum termsEnum = query.getTermsEnum(terms); assert termsEnum != null; - BitDocIdSet.Builder builder = new BitDocIdSet.Builder(context.reader().maxDoc()); PostingsEnum docs = null; final List collectedTerms = new ArrayList<>(); @@ -141,10 +159,11 @@ final class MultiTermQueryConstantScoreWrapper extends } Query q = new ConstantScoreQuery(bq); q.setBoost(score()); - return searcher.rewrite(q).createWeight(searcher, needsScores).scorer(context, acceptDocs); + return new WeightOrBitSet(searcher.rewrite(q).createWeight(searcher, needsScores)); } // Too many terms: go back to the terms we already collected and start building the bit set + BitDocIdSet.Builder builder = new BitDocIdSet.Builder(context.reader().maxDoc()); if (collectedTerms.isEmpty() == false) { TermsEnum termsEnum2 = terms.iterator(); for (TermAndState t : collectedTerms) { @@ -160,7 +179,10 @@ final class MultiTermQueryConstantScoreWrapper extends builder.or(docs); } while (termsEnum.next() != null); - final BitDocIdSet set = builder.build(); + return new WeightOrBitSet(builder.build()); + } + + private Scorer scorer(BitDocIdSet set) { if (set == null) { return null; } @@ -170,6 +192,30 @@ final class MultiTermQueryConstantScoreWrapper extends } return new ConstantScoreScorer(this, score(), disi); } + + @Override + public BulkScorer bulkScorer(LeafReaderContext context, Bits acceptDocs) throws IOException { + final WeightOrBitSet weightOrBitSet = rewrite(context, acceptDocs); + if (weightOrBitSet.weight != null) { + return weightOrBitSet.weight.bulkScorer(context, acceptDocs); + } else { + final Scorer scorer = scorer(weightOrBitSet.bitset); + if (scorer == null) { + return null; + } + return new DefaultBulkScorer(scorer); + } + } + + @Override + public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException { + final WeightOrBitSet weightOrBitSet = rewrite(context, acceptDocs); + if (weightOrBitSet.weight != null) { + return weightOrBitSet.weight.scorer(context, acceptDocs); + } else { + return scorer(weightOrBitSet.bitset); + } + } }; } }