LUCENE-8109: Propagate information about the minimum score in BooleanQuery.

This commit is contained in:
Adrien Grand 2017-12-29 13:24:23 +01:00
parent bc9836fd50
commit 33f421e798
6 changed files with 380 additions and 25 deletions

View File

@ -64,6 +64,11 @@ Optimizations
* LUCENE-7993: Phrase queries are now faster if total hit counts are not * LUCENE-7993: Phrase queries are now faster if total hit counts are not
required. (Adrien Grand) required. (Adrien Grand)
* LUCENE-8109: Boolean queries propagate information about the minimum
competitive score in order to make collection faster if there are disjunctions
or phrase queries as sub queries, which know how to leverage this information
to run faster. (Adrien Grand)
======================= Lucene 7.3.0 ======================= ======================= Lucene 7.3.0 =======================
API Changes API Changes

View File

@ -27,6 +27,7 @@ class ConjunctionScorer extends Scorer {
final DocIdSetIterator disi; final DocIdSetIterator disi;
final Scorer[] scorers; final Scorer[] scorers;
final Collection<Scorer> required; final Collection<Scorer> required;
final MaxScoreSumPropagator maxScorePropagator;
/** Create a new {@link ConjunctionScorer}, note that {@code scorers} must be a subset of {@code required}. */ /** Create a new {@link ConjunctionScorer}, note that {@code scorers} must be a subset of {@code required}. */
ConjunctionScorer(Weight weight, Collection<Scorer> required, Collection<Scorer> scorers) { ConjunctionScorer(Weight weight, Collection<Scorer> required, Collection<Scorer> scorers) {
@ -35,6 +36,7 @@ class ConjunctionScorer extends Scorer {
this.disi = ConjunctionDISI.intersectScorers(required); this.disi = ConjunctionDISI.intersectScorers(required);
this.scorers = scorers.toArray(new Scorer[scorers.size()]); this.scorers = scorers.toArray(new Scorer[scorers.size()]);
this.required = required; this.required = required;
this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
} }
@Override @Override
@ -63,22 +65,13 @@ class ConjunctionScorer extends Scorer {
@Override @Override
public float maxScore() { public float maxScore() {
// We iterate in the same order as #score() so no need to worry return maxScorePropagator.maxScore();
// about floating-point errors: we would do the same errors in
// #score()
double sum = 0d;
for (Scorer scorer : scorers) {
sum += scorer.maxScore();
}
return (float) sum;
} }
@Override @Override
public void setMinCompetitiveScore(float score) { public void setMinCompetitiveScore(float score) {
if (scorers.length == 1) { // Propagate to sub clauses.
scorers[0].setMinCompetitiveScore(score); maxScorePropagator.setMinCompetitiveScore(score);
}
// TODO: handle the case when there are multiple scoring clauses too
} }
@Override @Override

View File

@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.Collection;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.MathUtil;
/**
 * Utility class that propagates scoring information in {@link BooleanQuery},
 * which computes its score as the sum of the scores of its matching clauses.
 * It exposes an upper bound of the summed score and translates a minimum
 * competitive score on the sum into per-clause minimum competitive scores.
 */
final class MaxScoreSumPropagator {

  private final int numClauses;
  private final float maxScore;
  /** Clauses, ordered by decreasing max score. */
  private final Scorer[] scorers;
  /** For each entry of {@link #scorers}, the sum of the max scores of all other clauses. */
  private final double[] sumOfOtherMaxScores;

  MaxScoreSumPropagator(Collection<? extends Scorer> scorerList) {
    numClauses = scorerList.size();
    final Scorer[] clauses = scorerList.toArray(new Scorer[numClauses]);

    // Max scores are needed several times below, so read them only once.
    final float[] clauseMaxScores = new float[numClauses];
    for (int i = 0; i < numClauses; ++i) {
      clauseMaxScores[i] = clauses[i].maxScore();
    }

    // Order clauses by decreasing max score, keeping both arrays aligned.
    new InPlaceMergeSorter() {
      @Override
      protected int compare(int i, int j) {
        return Float.compare(clauseMaxScores[j], clauseMaxScores[i]);
      }

      @Override
      protected void swap(int i, int j) {
        float heldMax = clauseMaxScores[i];
        clauseMaxScores[i] = clauseMaxScores[j];
        clauseMaxScores[j] = heldMax;

        Scorer heldScorer = clauses[i];
        clauses[i] = clauses[j];
        clauses[j] = heldScorer;
      }
    }.sort(0, clauses.length);

    scorers = clauses;
    sumOfOtherMaxScores = computeSumOfComplement(clauseMaxScores);
    maxScore = numClauses == 0
        ? 0f
        : sumUpperBound(clauseMaxScores[0] + sumOfOtherMaxScores[0]);
  }

  /** Return the upper bound of the summed score that was computed at construction time. */
  public float maxScore() {
    return maxScore;
  }

  /**
   * Derive a minimum competitive score for each clause from the minimum
   * competitive score of the sum, and hand it to the clause.
   */
  public void setMinCompetitiveScore(float minScoreSum) {
    for (int i = 0; i < numClauses; ++i) {
      final float clauseMinScore = getMinCompetitiveScore(minScoreSum, sumOfOtherMaxScores[i]);
      if (clauseMinScore <= 0) {
        // Clauses are ordered by decreasing max score, so every following
        // clause would get 0 as a minimum competitive score as well.
        return;
      }
      scorers[i].setMinCompetitiveScore(clauseMinScore);
    }
  }

  /**
   * Return the minimum score that a Scorer must produce in order for a hit to
   * be competitive, given the summed max score of all other clauses.
   */
  private float getMinCompetitiveScore(float minScoreSum, double sumOfOtherMaxScores) {
    assert numClauses > 0;
    if (minScoreSum <= sumOfOtherMaxScores) {
      return 0f;
    }

    // Step down by the ulp of minScoreSum rather than the ulp of the candidate
    // so that the search below converges quickly.
    final float step = Math.ulp(minScoreSum);

    // Look for a value such that 'value + sumOfOtherMaxScores <= minScoreSum'.
    // TODO: is there an efficient way to find the greatest value that meets this requirement?
    float candidate = (float) (minScoreSum - sumOfOtherMaxScores);
    int iters = 0;
    while (sumUpperBound(candidate + sumOfOtherMaxScores) > minScoreSum) {
      candidate -= step;
      // At most two rounds are expected:
      //  - one for the rounding error of the subtraction
      //  - one for the error introduced by sumUpperBound
      assert ++iters <= 2 : iters;
    }
    return Math.max(candidate, 0f);
  }

  /** Return an upper bound of the float value that summing clause scores up to {@code sum} may yield. */
  private float sumUpperBound(double sum) {
    if (numClauses > 2) {
      // The rounding error of a sum depends on the order in which values are
      // added. If the max relative error is b, any two orderings are within
      // 2*b of each other, so scale the sum accordingly.
      // Conjunctions sum in a predictable order and could skip this factor,
      // but in practice the extra delta is usually cancelled by the float cast.
      final double relativeError = MathUtil.sumRelativeErrorBound(numClauses);
      return (float) ((1.0 + 2 * relativeError) * sum);
    }
    // With at most two clauses the sum is the same regardless of the order.
    return (float) sum;
  }

  /**
   * Return an array which, at index i, stores the sum of all entries of
   * {@code v} except the one at index i.
   */
  private static double[] computeSumOfComplement(float[] v) {
    // Subtraction is avoided on purpose: it would defeat the upper bound
    // formula used for sums. Instead of the naive O(n^2) approach, add the sum
    // of the entries before i to the sum of the entries after i, in O(n).
    final int n = v.length;
    final double[] complement = new double[n];

    // Forward pass: complement[i] holds v[0] + ... + v[i-1].
    for (int i = 1; i < n; ++i) {
      complement[i] = complement[i - 1] + v[i - 1];
    }

    // Backward pass: add v[i+1] + ... + v[n-1], accumulated from the right.
    double suffix = 0d;
    for (int i = n - 1; i >= 0; --i) {
      complement[i] += suffix;
      suffix += v[i];
    }
    return complement;
  }
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
/** A Scorer for queries with a required part and an optional part. /** A Scorer for queries with a required part and an optional part.
@ -35,6 +36,7 @@ class ReqOptSumScorer extends Scorer {
private boolean optIsRequired; private boolean optIsRequired;
private final DocIdSetIterator approximation; private final DocIdSetIterator approximation;
private final TwoPhaseIterator twoPhase; private final TwoPhaseIterator twoPhase;
final MaxScoreSumPropagator maxScorePropagator;
/** Construct a <code>ReqOptScorer</code>. /** Construct a <code>ReqOptScorer</code>.
* @param reqScorer The required scorer. This must match. * @param reqScorer The required scorer. This must match.
@ -51,6 +53,7 @@ class ReqOptSumScorer extends Scorer {
this.optScorer = optScorer; this.optScorer = optScorer;
this.reqMaxScore = reqScorer.maxScore(); this.reqMaxScore = reqScorer.maxScore();
this.maxScorePropagator = new MaxScoreSumPropagator(Arrays.asList(reqScorer, optScorer));
final TwoPhaseIterator reqTwoPhase = reqScorer.twoPhaseIterator(); final TwoPhaseIterator reqTwoPhase = reqScorer.twoPhaseIterator();
this.optTwoPhase = optScorer.twoPhaseIterator(); this.optTwoPhase = optScorer.twoPhaseIterator();
@ -208,14 +211,17 @@ class ReqOptSumScorer extends Scorer {
@Override @Override
public float maxScore() { public float maxScore() {
return reqScorer.maxScore() + optScorer.maxScore(); return maxScorePropagator.maxScore();
} }
@Override @Override
public void setMinCompetitiveScore(float minScore) { public void setMinCompetitiveScore(float minScore) {
// Potentially move to a conjunction
if (optIsRequired == false && minScore > reqMaxScore) { if (optIsRequired == false && minScore > reqMaxScore) {
optIsRequired = true; optIsRequired = true;
} }
// And also propagate to sub clauses.
maxScorePropagator.setMinCompetitiveScore(minScore);
} }
@Override @Override

View File

@ -26,8 +26,6 @@ import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.OptionalInt; import java.util.OptionalInt;
import org.apache.lucene.util.MathUtil;
/** /**
* This implements the WAND (Weak AND) algorithm for dynamic pruning * This implements the WAND (Weak AND) algorithm for dynamic pruning
* described in "Efficient Query Evaluation using a Two-Level Retrieval * described in "Efficient Query Evaluation using a Two-Level Retrieval
@ -122,7 +120,7 @@ final class WANDScorer extends Scorer {
int tailSize; int tailSize;
final long cost; final long cost;
final float maxScore; final MaxScoreSumPropagator maxScorePropagator;
WANDScorer(Weight weight, Collection<Scorer> scorers) { WANDScorer(Weight weight, Collection<Scorer> scorers) {
super(weight); super(weight);
@ -145,12 +143,10 @@ final class WANDScorer extends Scorer {
// Use a scaling factor of 0 if all max scores are either 0 or +Infty // Use a scaling factor of 0 if all max scores are either 0 or +Infty
this.scalingFactor = scalingFactor.orElse(0); this.scalingFactor = scalingFactor.orElse(0);
double maxScoreSum = 0;
for (Scorer scorer : scorers) { for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer); DisiWrapper w = new DisiWrapper(scorer);
float maxScore = scorer.maxScore(); float maxScore = scorer.maxScore();
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor); w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
maxScoreSum += maxScore;
addLead(w); addLead(w);
} }
@ -159,12 +155,7 @@ final class WANDScorer extends Scorer {
cost += w.cost; cost += w.cost;
} }
this.cost = cost; this.cost = cost;
// The error of sums depends on the order in which values are summed up. In this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size());
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum);
} }
// returns a boolean so that it can be called from assert // returns a boolean so that it can be called from assert
@ -195,10 +186,15 @@ final class WANDScorer extends Scorer {
@Override @Override
public void setMinCompetitiveScore(float minScore) { public void setMinCompetitiveScore(float minScore) {
// Let this disjunction know about the new min score so that it can skip
// over clauses that produce low scores.
assert minScore >= 0; assert minScore >= 0;
long scaledMinScore = scaleMinScore(minScore, scalingFactor); long scaledMinScore = scaleMinScore(minScore, scalingFactor);
assert scaledMinScore >= minCompetitiveScore; assert scaledMinScore >= minCompetitiveScore;
minCompetitiveScore = scaledMinScore; minCompetitiveScore = scaledMinScore;
// And also propagate to sub clauses.
maxScorePropagator.setMinCompetitiveScore(minScore);
} }
@Override @Override
@ -386,7 +382,7 @@ final class WANDScorer extends Scorer {
@Override @Override
public float maxScore() { public float maxScore() {
return maxScore; return maxScorePropagator.maxScore();
} }
@Override @Override

View File

@ -0,0 +1,198 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
/**
 * Unit tests for {@link MaxScoreSumPropagator}, driven by stub scorers that
 * expose a fixed max score and record the last minimum competitive score that
 * was pushed to them.
 */
public class TestMaxScoreSumPropagator extends LuceneTestCase {

  /**
   * Scorer stub: only {@link #maxScore()} and
   * {@link #setMinCompetitiveScore(float)} are usable, every other method throws.
   */
  private static class FakeScorer extends Scorer {

    final float maxScore;
    /** Last value received through {@link #setMinCompetitiveScore(float)}, 0 if none. */
    float minCompetitiveScore;

    FakeScorer(float maxScore) {
      super(null);
      this.maxScore = maxScore;
    }

    @Override
    public int docID() {
      throw new UnsupportedOperationException();
    }

    @Override
    public float score() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public DocIdSetIterator iterator() {
      throw new UnsupportedOperationException();
    }

    @Override
    public float maxScore() {
      return maxScore;
    }

    @Override
    public void setMinCompetitiveScore(float minCompetitiveScore) {
      this.minCompetitiveScore = minCompetitiveScore;
    }
  }

  /** An empty propagator has a max score of 0 and accepts any minimum score. */
  public void test0Clause() {
    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Collections.emptyList());
    assertEquals(0f, p.maxScore(), 0f);
    p.setMinCompetitiveScore(0f); // no exception
    p.setMinCompetitiveScore(0.5f); // no exception
  }

  /** With a single clause, the minimum score is forwarded unchanged. */
  public void test1Clause() {
    FakeScorer a = new FakeScorer(1);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Collections.singletonList(a));
    assertEquals(1f, p.maxScore(), 0f);
    p.setMinCompetitiveScore(0f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    p.setMinCompetitiveScore(0.5f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    p.setMinCompetitiveScore(1f);
    assertEquals(1f, a.minCompetitiveScore, 0f);
  }

  /** With two clauses, each clause gets the minimum score minus the other clause's max score. */
  public void test2Clauses() {
    FakeScorer a = new FakeScorer(1);
    FakeScorer b = new FakeScorer(2);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b));
    assertEquals(3f, p.maxScore(), 0f);

    p.setMinCompetitiveScore(1f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(1f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2.5f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    assertEquals(1.5f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(3f);
    assertEquals(1f, a.minCompetitiveScore, 0f);
    assertEquals(2f, b.minCompetitiveScore, 0f);
  }

  /** Same as {@link #test2Clauses()} with three clauses given in non-sorted order. */
  public void test3Clauses() {
    FakeScorer a = new FakeScorer(1);
    FakeScorer b = new FakeScorer(2);
    FakeScorer c = new FakeScorer(1.5f);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b, c));
    assertEquals(4.5f, p.maxScore(), 0f);

    p.setMinCompetitiveScore(1f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(3f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0.5f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(4f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    assertEquals(1.5f, b.minCompetitiveScore, 0f);
    assertEquals(1f, c.minCompetitiveScore, 0f);
  }

  /** Two clauses with random max scores, probed around the interesting thresholds. */
  public void test2ClausesRandomScore() {
    for (int iter = 0; iter < 10; ++iter) {
      FakeScorer a = new FakeScorer(random().nextFloat());
      // b always has a strictly greater max score than a
      FakeScorer b = new FakeScorer(Math.nextUp(a.maxScore()) + random().nextFloat());

      MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b));
      assertEquals(a.maxScore() + b.maxScore(), p.maxScore(), 0f);
      assertMinCompetitiveScore(Arrays.asList(a, b), p, Math.nextUp(a.maxScore()));
      assertMinCompetitiveScore(Arrays.asList(a, b), p, (a.maxScore() + b.maxScore()) / 2);
      assertMinCompetitiveScore(Arrays.asList(a, b), p, Math.nextDown(a.maxScore() + b.maxScore()));
      assertMinCompetitiveScore(Arrays.asList(a, b), p, a.maxScore() + b.maxScore());
    }
  }

  /** Three or more clauses with random max scores and random minimum competitive scores. */
  public void testNClausesRandomScore() {
    for (int iter = 0; iter < 100; ++iter) {
      List<FakeScorer> scorers = new ArrayList<>();
      int numScorers = TestUtil.nextInt(random(), 3, 4 << random().nextInt(8));
      double sumOfMaxScore = 0;
      for (int i = 0; i < numScorers; ++i) {
        float maxScore = random().nextFloat();
        scorers.add(new FakeScorer(maxScore));
        sumOfMaxScore += maxScore;
      }

      MaxScoreSumPropagator p = new MaxScoreSumPropagator(scorers);
      assertTrue(p.maxScore() >= (float) sumOfMaxScore);
      for (int i = 0; i < 10; ++i) {
        final float minCompetitiveScore = random().nextFloat() * numScorers;
        // assertMinCompetitiveScore resets the fake scorers before propagating
        assertMinCompetitiveScore(scorers, p, minCompetitiveScore);
      }
    }
  }

  /**
   * Propagate {@code minCompetitiveScore} through {@code p} and check, for
   * every scorer that received a non-zero minimum competitive score, that a
   * hit scoring exactly that minimum on this clause and the max score on all
   * other clauses does not exceed {@code minCompetitiveScore}.
   *
   * @param scorers the scorers that {@code p} was built from
   * @param p the propagator under test
   * @param minCompetitiveScore the minimum competitive score of the sum
   */
  private void assertMinCompetitiveScore(Collection<FakeScorer> scorers, MaxScoreSumPropagator p, float minCompetitiveScore) {
    // Start from a clean slate: the propagator only notifies clauses whose
    // minimum competitive score is positive, so a value recorded by a previous
    // call could otherwise be mistaken for the outcome of this one.
    for (FakeScorer scorer : scorers) {
      scorer.minCompetitiveScore = 0;
    }

    p.setMinCompetitiveScore(minCompetitiveScore);

    for (FakeScorer scorer : scorers) {
      if (scorer.minCompetitiveScore == 0f) {
        // No propagation was performed for this clause, it still visits all
        // hits. Scorers are iterated in the caller's order rather than by
        // decreasing max score, so later ones may still have received a
        // value: keep checking them instead of stopping here.
        continue;
      }
      double scoreSum = scorer.minCompetitiveScore;
      for (FakeScorer scorer2 : scorers) {
        if (scorer2 != scorer) {
          scoreSum += scorer2.maxScore();
        }
      }
      assertTrue(
          "scoreSum=" + scoreSum + ", minCompetitiveScore=" + minCompetitiveScore,
          (float) scoreSum <= minCompetitiveScore);
    }
  }
}