mirror of https://github.com/apache/lucene.git
Speed up advancing on the disjunction iterator. (#14052)
Currently, the disjunction iterator puts all clauses in a heap in order to be able to merge doc IDs in a streaming fashion. This is a good approach for exhaustive evaluation, when only one clause moves to a different doc ID on average and the per-iteration cost is in the order of O(log(N)) where N is the number of clauses. However, if a selective filter is applied, this could cause many clauses to move to a different doc ID. In the worst-case scenario, all clauses could move to a different doc ID and the cost of maintaiting heap invariants could grow to O(N * log(N)) (every clause introduces a O(log(N)) cost). With many clauses, this is much higher than the cost of checking all clauses sequentially: O(N). To protect from this reordering overhead, DisjunctionDISIApproximation now only puts the cheapest clauses in a heap in a way that tries to achieve up to 1.5 clauses moving to a different doc ID on average. More expensive clauses are checked linearly.
This commit is contained in:
parent
a8d8d6b3d9
commit
bc341f2b3e
|
@ -47,7 +47,8 @@ Improvements
|
|||
|
||||
Optimizations
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
* GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
|
|
@ -28,7 +28,6 @@ import org.apache.lucene.index.Terms;
|
|||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOSupplier;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
|
@ -151,7 +150,8 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
|
|||
int fieldDocCount,
|
||||
Terms terms,
|
||||
TermsEnum termsEnum,
|
||||
List<TermAndState> collectedTerms)
|
||||
List<TermAndState> collectedTerms,
|
||||
long leadCost)
|
||||
throws IOException;
|
||||
|
||||
private WeightOrDocIdSetIterator rewriteAsBooleanQuery(
|
||||
|
@ -247,21 +247,22 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
|
|||
cost = estimateCost(terms, q.getTermsCount());
|
||||
}
|
||||
|
||||
IOSupplier<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
|
||||
() -> {
|
||||
IOLongFunction<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
|
||||
leadCost -> {
|
||||
if (collectResult) {
|
||||
return rewriteAsBooleanQuery(context, collectedTerms);
|
||||
} else {
|
||||
// Too many terms to rewrite as a simple bq.
|
||||
// Invoke rewriteInner logic to handle rewriting:
|
||||
return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms);
|
||||
return rewriteInner(
|
||||
context, fieldDocCount, terms, termsEnum, collectedTerms, leadCost);
|
||||
}
|
||||
};
|
||||
|
||||
return new ScorerSupplier() {
|
||||
@Override
|
||||
public Scorer get(long leadCost) throws IOException {
|
||||
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
|
||||
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.apply(leadCost);
|
||||
final Scorer scorer;
|
||||
if (weightOrIterator == null) {
|
||||
scorer = null;
|
||||
|
@ -281,7 +282,8 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
|
|||
|
||||
@Override
|
||||
public BulkScorer bulkScorer() throws IOException {
|
||||
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
|
||||
WeightOrDocIdSetIterator weightOrIterator =
|
||||
weightOrIteratorSupplier.apply(Long.MAX_VALUE);
|
||||
final BulkScorer bulkScorer;
|
||||
if (weightOrIterator == null) {
|
||||
bulkScorer = null;
|
||||
|
@ -311,6 +313,10 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
|
|||
};
|
||||
}
|
||||
|
||||
private static interface IOLongFunction<T> {
|
||||
T apply(long arg) throws IOException;
|
||||
}
|
||||
|
||||
private static long estimateCost(Terms terms, long queryTermsCount) throws IOException {
|
||||
// Estimate the cost. If the MTQ can provide its term count, we can do a better job
|
||||
// estimating.
|
||||
|
|
|
@ -237,7 +237,8 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
|||
Scorer prohibitedScorer =
|
||||
prohibited.size() == 1
|
||||
? prohibited.get(0)
|
||||
: new DisjunctionSumScorer(prohibited, ScoreMode.COMPLETE_NO_SCORES);
|
||||
: new DisjunctionSumScorer(
|
||||
prohibited, ScoreMode.COMPLETE_NO_SCORES, positiveScorerCost);
|
||||
return new ReqExclBulkScorer(positiveScorer, prohibitedScorer);
|
||||
}
|
||||
}
|
||||
|
@ -509,7 +510,7 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
|||
if ((scoreMode == ScoreMode.TOP_SCORES && topLevelScoringClause) || minShouldMatch > 1) {
|
||||
return new WANDScorer(optionalScorers, minShouldMatch, scoreMode, leadCost);
|
||||
} else {
|
||||
return new DisjunctionSumScorer(optionalScorers, scoreMode);
|
||||
return new DisjunctionSumScorer(optionalScorers, scoreMode, leadCost);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided
|
||||
|
@ -24,18 +28,75 @@ import java.io.IOException;
|
|||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public class DisjunctionDISIApproximation extends DocIdSetIterator {
|
||||
public final class DisjunctionDISIApproximation extends DocIdSetIterator {
|
||||
|
||||
final DisiPriorityQueue subIterators;
|
||||
final long cost;
|
||||
public static DisjunctionDISIApproximation of(
|
||||
Collection<DisiWrapper> subIterators, long leadCost) {
|
||||
|
||||
return new DisjunctionDISIApproximation(subIterators, leadCost);
|
||||
}
|
||||
|
||||
// Heap of iterators that lead iteration.
|
||||
private final DisiPriorityQueue leadIterators;
|
||||
// List of iterators that will likely advance on every call to nextDoc() / advance()
|
||||
private final DisiWrapper[] otherIterators;
|
||||
private final long cost;
|
||||
private DisiWrapper leadTop;
|
||||
private int minOtherDoc;
|
||||
|
||||
public DisjunctionDISIApproximation(Collection<DisiWrapper> subIterators, long leadCost) {
|
||||
// Using a heap to store disjunctive clauses is great for exhaustive evaluation, when a single
|
||||
// clause needs to move through the heap on every iteration on average. However, when
|
||||
// intersecting with a selective filter, it is possible that all clauses need advancing, which
|
||||
// makes the reordering cost scale in O(N * log(N)) per advance() call when checking clauses
|
||||
// linearly would scale in O(N).
|
||||
// To protect against this reordering overhead, we try to have 1.5 clauses or less that advance
|
||||
// on every advance() call by only putting clauses into the heap as long as Σ min(1, cost /
|
||||
// leadCost) <= 1.5, or Σ min(leadCost, cost) <= 1.5 * leadCost. Other clauses are checked
|
||||
// linearly.
|
||||
|
||||
List<DisiWrapper> wrappers = new ArrayList<>(subIterators);
|
||||
// Sort by descending cost.
|
||||
wrappers.sort(Comparator.<DisiWrapper>comparingLong(w -> w.cost).reversed());
|
||||
|
||||
leadIterators = new DisiPriorityQueue(subIterators.size());
|
||||
|
||||
long reorderThreshold = leadCost + (leadCost >> 1);
|
||||
if (reorderThreshold < 0) { // overflow
|
||||
reorderThreshold = Long.MAX_VALUE;
|
||||
}
|
||||
long reorderCost = 0;
|
||||
while (wrappers.isEmpty() == false) {
|
||||
DisiWrapper last = wrappers.getLast();
|
||||
long inc = Math.min(last.cost, leadCost);
|
||||
if (reorderCost + inc < 0 || reorderCost + inc > reorderThreshold) {
|
||||
break;
|
||||
}
|
||||
leadIterators.add(wrappers.removeLast());
|
||||
reorderCost += inc;
|
||||
}
|
||||
|
||||
// Make leadIterators not empty. This helps save conditionals in the implementation which are
|
||||
// rarely tested.
|
||||
if (leadIterators.size() == 0) {
|
||||
leadIterators.add(wrappers.removeLast());
|
||||
}
|
||||
|
||||
otherIterators = wrappers.toArray(DisiWrapper[]::new);
|
||||
|
||||
public DisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
|
||||
this.subIterators = subIterators;
|
||||
long cost = 0;
|
||||
for (DisiWrapper w : subIterators) {
|
||||
for (DisiWrapper w : leadIterators) {
|
||||
cost += w.cost;
|
||||
}
|
||||
for (DisiWrapper w : otherIterators) {
|
||||
cost += w.cost;
|
||||
}
|
||||
this.cost = cost;
|
||||
minOtherDoc = Integer.MAX_VALUE;
|
||||
for (DisiWrapper w : otherIterators) {
|
||||
minOtherDoc = Math.min(minOtherDoc, w.doc);
|
||||
}
|
||||
leadTop = leadIterators.top();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,29 +106,62 @@ public class DisjunctionDISIApproximation extends DocIdSetIterator {
|
|||
|
||||
@Override
|
||||
public int docID() {
|
||||
return subIterators.top().doc;
|
||||
return Math.min(minOtherDoc, leadTop.doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
DisiWrapper top = subIterators.top();
|
||||
final int doc = top.doc;
|
||||
do {
|
||||
top.doc = top.approximation.nextDoc();
|
||||
top = subIterators.updateTop();
|
||||
} while (top.doc == doc);
|
||||
|
||||
return top.doc;
|
||||
if (leadTop.doc < minOtherDoc) {
|
||||
int curDoc = leadTop.doc;
|
||||
do {
|
||||
leadTop.doc = leadTop.approximation.nextDoc();
|
||||
leadTop = leadIterators.updateTop();
|
||||
} while (leadTop.doc == curDoc);
|
||||
return Math.min(leadTop.doc, minOtherDoc);
|
||||
} else {
|
||||
return advance(minOtherDoc + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
DisiWrapper top = subIterators.top();
|
||||
do {
|
||||
top.doc = top.approximation.advance(target);
|
||||
top = subIterators.updateTop();
|
||||
} while (top.doc < target);
|
||||
while (leadTop.doc < target) {
|
||||
leadTop.doc = leadTop.approximation.advance(target);
|
||||
leadTop = leadIterators.updateTop();
|
||||
}
|
||||
|
||||
return top.doc;
|
||||
minOtherDoc = Integer.MAX_VALUE;
|
||||
for (DisiWrapper w : otherIterators) {
|
||||
if (w.doc < target) {
|
||||
w.doc = w.approximation.advance(target);
|
||||
}
|
||||
minOtherDoc = Math.min(minOtherDoc, w.doc);
|
||||
}
|
||||
|
||||
return Math.min(leadTop.doc, minOtherDoc);
|
||||
}
|
||||
|
||||
/** Return the linked list of iterators positioned on the current doc. */
|
||||
public DisiWrapper topList() {
|
||||
if (leadTop.doc < minOtherDoc) {
|
||||
return leadIterators.topList();
|
||||
} else {
|
||||
return computeTopList();
|
||||
}
|
||||
}
|
||||
|
||||
private DisiWrapper computeTopList() {
|
||||
assert leadTop.doc >= minOtherDoc;
|
||||
DisiWrapper topList = null;
|
||||
if (leadTop.doc == minOtherDoc) {
|
||||
topList = leadIterators.topList();
|
||||
}
|
||||
for (DisiWrapper w : otherIterators) {
|
||||
if (w.doc == minOtherDoc) {
|
||||
w.next = topList;
|
||||
topList = w;
|
||||
}
|
||||
}
|
||||
return topList;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -155,7 +155,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
|||
for (ScorerSupplier ss : scorerSuppliers) {
|
||||
scorers.add(ss.get(leadCost));
|
||||
}
|
||||
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode);
|
||||
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode, leadCost);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -40,9 +40,10 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
|
|||
* as they are summed into the result.
|
||||
* @param subScorers The sub scorers this Scorer should iterate on
|
||||
*/
|
||||
DisjunctionMaxScorer(float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode)
|
||||
DisjunctionMaxScorer(
|
||||
float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
|
||||
throws IOException {
|
||||
super(subScorers, scoreMode);
|
||||
super(subScorers, scoreMode, leadCost);
|
||||
this.subScorers = subScorers;
|
||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||
|
|
|
@ -25,37 +25,34 @@ import org.apache.lucene.util.PriorityQueue;
|
|||
/** Base class for Scorers that score disjunctions. */
|
||||
abstract class DisjunctionScorer extends Scorer {
|
||||
|
||||
private final int numClauses;
|
||||
private final boolean needsScores;
|
||||
|
||||
private final DisiPriorityQueue subScorers;
|
||||
private final DocIdSetIterator approximation;
|
||||
private final DisjunctionDISIApproximation approximation;
|
||||
private final TwoPhase twoPhase;
|
||||
|
||||
protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
|
||||
protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
|
||||
throws IOException {
|
||||
if (subScorers.size() <= 1) {
|
||||
throw new IllegalArgumentException("There must be at least 2 subScorers");
|
||||
}
|
||||
this.subScorers = new DisiPriorityQueue(subScorers.size());
|
||||
for (Scorer scorer : subScorers) {
|
||||
final DisiWrapper w = new DisiWrapper(scorer, false);
|
||||
this.subScorers.add(w);
|
||||
}
|
||||
this.numClauses = subScorers.size();
|
||||
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
|
||||
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
|
||||
|
||||
boolean hasApproximation = false;
|
||||
float sumMatchCost = 0;
|
||||
long sumApproxCost = 0;
|
||||
// Compute matchCost as the average over the matchCost of the subScorers.
|
||||
// This is weighted by the cost, which is an expected number of matching documents.
|
||||
for (DisiWrapper w : this.subScorers) {
|
||||
List<DisiWrapper> wrappers = new ArrayList<>();
|
||||
for (Scorer scorer : subScorers) {
|
||||
DisiWrapper w = new DisiWrapper(scorer, false);
|
||||
long costWeight = (w.cost <= 1) ? 1 : w.cost;
|
||||
sumApproxCost += costWeight;
|
||||
if (w.twoPhaseView != null) {
|
||||
hasApproximation = true;
|
||||
sumMatchCost += w.matchCost * costWeight;
|
||||
}
|
||||
wrappers.add(w);
|
||||
}
|
||||
this.approximation = new DisjunctionDISIApproximation(wrappers, leadCost);
|
||||
|
||||
if (hasApproximation == false) { // no sub scorer supports approximations
|
||||
twoPhase = null;
|
||||
|
@ -91,7 +88,7 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
super(approximation);
|
||||
this.matchCost = matchCost;
|
||||
unverifiedMatches =
|
||||
new PriorityQueue<DisiWrapper>(DisjunctionScorer.this.subScorers.size()) {
|
||||
new PriorityQueue<DisiWrapper>(numClauses) {
|
||||
@Override
|
||||
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
|
||||
return a.matchCost < b.matchCost;
|
||||
|
@ -116,7 +113,7 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
verifiedMatches = null;
|
||||
unverifiedMatches.clear();
|
||||
|
||||
for (DisiWrapper w = subScorers.topList(); w != null; ) {
|
||||
for (DisiWrapper w = DisjunctionScorer.this.approximation.topList(); w != null; ) {
|
||||
DisiWrapper next = w.next;
|
||||
|
||||
if (w.twoPhaseView == null) {
|
||||
|
@ -160,12 +157,12 @@ abstract class DisjunctionScorer extends Scorer {
|
|||
|
||||
@Override
|
||||
public final int docID() {
|
||||
return subScorers.top().doc;
|
||||
return approximation.docID();
|
||||
}
|
||||
|
||||
DisiWrapper getSubMatches() throws IOException {
|
||||
if (twoPhase == null) {
|
||||
return subScorers.topList();
|
||||
return approximation.topList();
|
||||
} else {
|
||||
return twoPhase.getSubMatches();
|
||||
}
|
||||
|
|
|
@ -30,8 +30,9 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
|
|||
*
|
||||
* @param subScorers Array of at least two subscorers.
|
||||
*/
|
||||
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
|
||||
super(subScorers, scoreMode);
|
||||
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
|
||||
throws IOException {
|
||||
super(subScorers, scoreMode, leadCost);
|
||||
this.scorers = subScorers;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
@ -27,18 +28,17 @@ import java.util.List;
|
|||
public abstract class IndriDisjunctionScorer extends IndriScorer {
|
||||
|
||||
private final List<Scorer> subScorersList;
|
||||
private final DisiPriorityQueue subScorers;
|
||||
private final DocIdSetIterator approximation;
|
||||
|
||||
protected IndriDisjunctionScorer(List<Scorer> subScorersList, ScoreMode scoreMode, float boost) {
|
||||
super(boost);
|
||||
this.subScorersList = subScorersList;
|
||||
this.subScorers = new DisiPriorityQueue(subScorersList.size());
|
||||
List<DisiWrapper> wrappers = new ArrayList<>();
|
||||
for (Scorer scorer : subScorersList) {
|
||||
final DisiWrapper w = new DisiWrapper(scorer, false);
|
||||
this.subScorers.add(w);
|
||||
wrappers.add(w);
|
||||
}
|
||||
this.approximation = new DisjunctionDISIApproximation(this.subScorers);
|
||||
this.approximation = new DisjunctionDISIApproximation(wrappers, Long.MAX_VALUE);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -71,6 +71,6 @@ public abstract class IndriDisjunctionScorer extends IndriScorer {
|
|||
|
||||
@Override
|
||||
public int docID() {
|
||||
return subScorers.top().doc;
|
||||
return approximation.docID();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
|
@ -52,7 +53,8 @@ final class MultiTermQueryConstantScoreBlendedWrapper<Q extends MultiTermQuery>
|
|||
int fieldDocCount,
|
||||
Terms terms,
|
||||
TermsEnum termsEnum,
|
||||
List<TermAndState> collectedTerms)
|
||||
List<TermAndState> collectedTerms,
|
||||
long leadCost)
|
||||
throws IOException {
|
||||
DocIdSetBuilder otherTerms = new DocIdSetBuilder(context.reader().maxDoc(), terms);
|
||||
PriorityQueue<PostingsEnum> highFrequencyTerms =
|
||||
|
@ -110,7 +112,7 @@ final class MultiTermQueryConstantScoreBlendedWrapper<Q extends MultiTermQuery>
|
|||
}
|
||||
} while (termsEnum.next() != null);
|
||||
|
||||
DisiPriorityQueue subs = new DisiPriorityQueue(highFrequencyTerms.size() + 1);
|
||||
List<DisiWrapper> subs = new ArrayList<>(highFrequencyTerms.size() + 1);
|
||||
for (DocIdSetIterator disi : highFrequencyTerms) {
|
||||
Scorer s = wrapWithDummyScorer(this, disi);
|
||||
subs.add(new DisiWrapper(s, false));
|
||||
|
@ -118,7 +120,7 @@ final class MultiTermQueryConstantScoreBlendedWrapper<Q extends MultiTermQuery>
|
|||
Scorer s = wrapWithDummyScorer(this, otherTerms.build().iterator());
|
||||
subs.add(new DisiWrapper(s, false));
|
||||
|
||||
return new WeightOrDocIdSetIterator(new DisjunctionDISIApproximation(subs));
|
||||
return new WeightOrDocIdSetIterator(new DisjunctionDISIApproximation(subs, leadCost));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -49,7 +49,8 @@ final class MultiTermQueryConstantScoreWrapper<Q extends MultiTermQuery>
|
|||
int fieldDocCount,
|
||||
Terms terms,
|
||||
TermsEnum termsEnum,
|
||||
List<TermAndState> collectedTerms)
|
||||
List<TermAndState> collectedTerms,
|
||||
long leadCost)
|
||||
throws IOException {
|
||||
DocIdSetBuilder builder = new DocIdSetBuilder(context.reader().maxDoc(), terms);
|
||||
PostingsEnum docs = null;
|
||||
|
|
|
@ -357,17 +357,19 @@ public final class SynonymQuery extends Query {
|
|||
} else {
|
||||
|
||||
// we use termscorers + disjunction as an impl detail
|
||||
DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
|
||||
List<DisiWrapper> wrappers = new ArrayList<>();
|
||||
for (int i = 0; i < iterators.size(); i++) {
|
||||
PostingsEnum postings = iterators.get(i);
|
||||
final TermScorer termScorer = new TermScorer(postings, simWeight, norms);
|
||||
float boost = termBoosts.get(i);
|
||||
final DisiWrapperFreq wrapper = new DisiWrapperFreq(termScorer, boost);
|
||||
queue.add(wrapper);
|
||||
wrappers.add(wrapper);
|
||||
}
|
||||
// Even though it is called approximation, it is accurate since none of
|
||||
// the sub iterators are two-phase iterators.
|
||||
DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue);
|
||||
DisjunctionDISIApproximation disjunctionIterator =
|
||||
new DisjunctionDISIApproximation(wrappers, leadCost);
|
||||
DocIdSetIterator iterator = disjunctionIterator;
|
||||
|
||||
float[] boosts = new float[impacts.size()];
|
||||
for (int i = 0; i < boosts.length; i++) {
|
||||
|
@ -384,7 +386,7 @@ public final class SynonymQuery extends Query {
|
|||
iterator = impactsDisi;
|
||||
}
|
||||
|
||||
return new SynonymScorer(queue, iterator, impactsDisi, simWeight, norms);
|
||||
return new SynonymScorer(iterator, disjunctionIterator, impactsDisi, simWeight, norms);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -576,21 +578,21 @@ public final class SynonymQuery extends Query {
|
|||
|
||||
private static class SynonymScorer extends Scorer {
|
||||
|
||||
private final DisiPriorityQueue queue;
|
||||
private final DocIdSetIterator iterator;
|
||||
private final DisjunctionDISIApproximation disjunctionDisi;
|
||||
private final MaxScoreCache maxScoreCache;
|
||||
private final ImpactsDISI impactsDisi;
|
||||
private final SimScorer scorer;
|
||||
private final NumericDocValues norms;
|
||||
|
||||
SynonymScorer(
|
||||
DisiPriorityQueue queue,
|
||||
DocIdSetIterator iterator,
|
||||
DisjunctionDISIApproximation disjunctionDisi,
|
||||
ImpactsDISI impactsDisi,
|
||||
SimScorer scorer,
|
||||
NumericDocValues norms) {
|
||||
this.queue = queue;
|
||||
this.iterator = iterator;
|
||||
this.disjunctionDisi = disjunctionDisi;
|
||||
this.maxScoreCache = impactsDisi.getMaxScoreCache();
|
||||
this.impactsDisi = impactsDisi;
|
||||
this.scorer = scorer;
|
||||
|
@ -603,7 +605,7 @@ public final class SynonymQuery extends Query {
|
|||
}
|
||||
|
||||
float freq() throws IOException {
|
||||
DisiWrapperFreq w = (DisiWrapperFreq) queue.topList();
|
||||
DisiWrapperFreq w = (DisiWrapperFreq) disjunctionDisi.topList();
|
||||
float freq = w.freq();
|
||||
for (w = (DisiWrapperFreq) w.next; w != null; w = (DisiWrapperFreq) w.next) {
|
||||
freq += w.freq();
|
||||
|
|
|
@ -39,7 +39,6 @@ import org.apache.lucene.index.TermsEnum;
|
|||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.DisiPriorityQueue;
|
||||
import org.apache.lucene.search.DisiWrapper;
|
||||
import org.apache.lucene.search.DisjunctionDISIApproximation;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -383,6 +382,7 @@ public final class CombinedFieldQuery extends Query implements Accountable {
|
|||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
List<PostingsEnum> iterators = new ArrayList<>();
|
||||
List<FieldAndWeight> fields = new ArrayList<>();
|
||||
long cost = 0;
|
||||
for (int i = 0; i < fieldTerms.length; i++) {
|
||||
IOSupplier<TermState> supplier = termStates[i].get(context);
|
||||
TermState state = supplier == null ? null : supplier.get();
|
||||
|
@ -392,6 +392,7 @@ public final class CombinedFieldQuery extends Query implements Accountable {
|
|||
PostingsEnum postingsEnum = termsEnum.postings(null, PostingsEnum.FREQS);
|
||||
iterators.add(postingsEnum);
|
||||
fields.add(fieldAndWeights.get(fieldTerms[i].field()));
|
||||
cost += postingsEnum.cost();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -401,18 +402,31 @@ public final class CombinedFieldQuery extends Query implements Accountable {
|
|||
|
||||
MultiNormsLeafSimScorer scoringSimScorer =
|
||||
new MultiNormsLeafSimScorer(simWeight, context.reader(), fieldAndWeights.values(), true);
|
||||
// we use termscorers + disjunction as an impl detail
|
||||
DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
|
||||
for (int i = 0; i < iterators.size(); i++) {
|
||||
float weight = fields.get(i).weight;
|
||||
queue.add(
|
||||
new WeightedDisiWrapper(new TermScorer(iterators.get(i), simWeight, null), weight));
|
||||
}
|
||||
// Even though it is called approximation, it is accurate since none of
|
||||
// the sub iterators are two-phase iterators.
|
||||
DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue);
|
||||
final var scorer = new CombinedFieldScorer(queue, iterator, scoringSimScorer);
|
||||
return new DefaultScorerSupplier(scorer);
|
||||
|
||||
final long finalCost = cost;
|
||||
return new ScorerSupplier() {
|
||||
|
||||
@Override
|
||||
public Scorer get(long leadCost) throws IOException {
|
||||
// we use termscorers + disjunction as an impl detail
|
||||
List<DisiWrapper> wrappers = new ArrayList<>(iterators.size());
|
||||
for (int i = 0; i < iterators.size(); i++) {
|
||||
float weight = fields.get(i).weight;
|
||||
wrappers.add(
|
||||
new WeightedDisiWrapper(new TermScorer(iterators.get(i), simWeight, null), weight));
|
||||
}
|
||||
// Even though it is called approximation, it is accurate since none of
|
||||
// the sub iterators are two-phase iterators.
|
||||
DisjunctionDISIApproximation iterator =
|
||||
new DisjunctionDISIApproximation(wrappers, leadCost);
|
||||
return new CombinedFieldScorer(iterator, scoringSimScorer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return finalCost;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -437,14 +451,11 @@ public final class CombinedFieldQuery extends Query implements Accountable {
|
|||
}
|
||||
|
||||
private static class CombinedFieldScorer extends Scorer {
|
||||
private final DisiPriorityQueue queue;
|
||||
private final DocIdSetIterator iterator;
|
||||
private final DisjunctionDISIApproximation iterator;
|
||||
private final MultiNormsLeafSimScorer simScorer;
|
||||
private final float maxScore;
|
||||
|
||||
CombinedFieldScorer(
|
||||
DisiPriorityQueue queue, DocIdSetIterator iterator, MultiNormsLeafSimScorer simScorer) {
|
||||
this.queue = queue;
|
||||
CombinedFieldScorer(DisjunctionDISIApproximation iterator, MultiNormsLeafSimScorer simScorer) {
|
||||
this.iterator = iterator;
|
||||
this.simScorer = simScorer;
|
||||
this.maxScore = simScorer.getSimScorer().score(Float.POSITIVE_INFINITY, 1L);
|
||||
|
@ -456,7 +467,7 @@ public final class CombinedFieldQuery extends Query implements Accountable {
|
|||
}
|
||||
|
||||
float freq() throws IOException {
|
||||
DisiWrapper w = queue.topList();
|
||||
DisiWrapper w = iterator.topList();
|
||||
float freq = ((WeightedDisiWrapper) w).freq();
|
||||
for (w = w.next; w != null; w = w.next) {
|
||||
freq += ((WeightedDisiWrapper) w).freq();
|
||||
|
|
Loading…
Reference in New Issue