LUCENE-8109: Propagate information about the minimum score in BooleanQuery.

This commit is contained in:
Adrien Grand 2017-12-29 13:24:23 +01:00
parent bc9836fd50
commit 33f421e798
6 changed files with 380 additions and 25 deletions

View File

@ -64,6 +64,11 @@ Optimizations
* LUCENE-7993: Phrase queries are now faster if total hit counts are not
required. (Adrien Grand)
* LUCENE-8109: Boolean queries propagate information about the minimum
competitive score in order to make collection faster if there are disjunctions
or phrase queries as sub queries, which know how to leverage this information
to run faster. (Adrien Grand)
======================= Lucene 7.3.0 =======================
API Changes

View File

@ -27,6 +27,7 @@ class ConjunctionScorer extends Scorer {
final DocIdSetIterator disi;
final Scorer[] scorers;
final Collection<Scorer> required;
final MaxScoreSumPropagator maxScorePropagator;
/** Create a new {@link ConjunctionScorer}, note that {@code scorers} must be a subset of {@code required}. */
ConjunctionScorer(Weight weight, Collection<Scorer> required, Collection<Scorer> scorers) {
@ -35,6 +36,7 @@ class ConjunctionScorer extends Scorer {
this.disi = ConjunctionDISI.intersectScorers(required);
this.scorers = scorers.toArray(new Scorer[scorers.size()]);
this.required = required;
this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
}
@Override
@ -63,22 +65,13 @@ class ConjunctionScorer extends Scorer {
@Override
public float maxScore() {
// We iterate in the same order as #score() so no need to worry
// about floating-point errors: we would do the same errors in
// #score()
double sum = 0d;
for (Scorer scorer : scorers) {
sum += scorer.maxScore();
}
return (float) sum;
return maxScorePropagator.maxScore();
}
@Override
public void setMinCompetitiveScore(float score) {
if (scorers.length == 1) {
scorers[0].setMinCompetitiveScore(score);
}
// TODO: handle the case when there are multiple scoring clauses too
// Propagate to sub clauses.
maxScorePropagator.setMinCompetitiveScore(score);
}
@Override

View File

@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.Collection;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.MathUtil;
/**
 * Helper for scorers whose score is the sum of the scores of their clauses,
 * as is the case for {@link BooleanQuery}. It exposes an upper bound of the
 * summed score and turns a minimum competitive score expressed on the sum
 * into minimum competitive scores for the individual clauses.
 */
final class MaxScoreSumPropagator {

  private final int numClauses;
  private final float maxScore;
  /** The clauses, reordered by decreasing max score by the constructor. */
  private final Scorer[] scorers;
  /** At index i, the sum of the max scores of every clause but {@code scorers[i]}. */
  private final double[] sumOfOtherMaxScores;

  MaxScoreSumPropagator(Collection<? extends Scorer> scorerList) {
    numClauses = scorerList.size();
    scorers = scorerList.toArray(new Scorer[numClauses]);

    // Max scores are read several times below, so ask each scorer only once.
    final float[] maxScores = new float[numClauses];
    for (int clause = 0; clause < numClauses; ++clause) {
      maxScores[clause] = scorers[clause].maxScore();
    }

    // Order both parallel arrays by decreasing max score.
    new InPlaceMergeSorter() {
      @Override
      protected int compare(int i, int j) {
        // Arguments are flipped on purpose to get a descending order.
        return Float.compare(maxScores[j], maxScores[i]);
      }

      @Override
      protected void swap(int i, int j) {
        final Scorer scorerAtI = scorers[i];
        final float maxScoreAtI = maxScores[i];
        scorers[i] = scorers[j];
        maxScores[i] = maxScores[j];
        scorers[j] = scorerAtI;
        maxScores[j] = maxScoreAtI;
      }
    }.sort(0, numClauses);

    sumOfOtherMaxScores = computeSumOfComplement(maxScores);
    maxScore = numClauses == 0
        ? 0
        : sumUpperBound(maxScores[0] + sumOfOtherMaxScores[0]);
  }

  /**
   * Build an array whose entry i is the sum of every entry of {@code v}
   * other than {@code v[i]}.
   */
  private static double[] computeSumOfComplement(float[] v) {
    // Subtracting v[i] from the total is avoided on purpose: it would defeat
    // the upper-bound formula used for sums. Instead, the sum of the entries
    // before i and the sum of the entries after i are each built in a single
    // pass and then added, which keeps the whole computation linear.
    final int n = v.length;
    final double[] before = new double[n];
    final double[] after = new double[n];
    for (int lo = 1, hi = n - 2; lo < n; ++lo, --hi) {
      before[lo] = before[lo - 1] + v[lo - 1];
      after[hi] = after[hi + 1] + v[hi + 1];
    }

    final double[] complement = new double[n];
    for (int i = 0; i < n; ++i) {
      complement[i] = before[i] + after[i];
    }
    return complement;
  }

  /** Return the upper bound of the summed score that was computed at construction time. */
  public float maxScore() {
    return maxScore;
  }

  /**
   * Derive a minimum competitive score for each clause from the minimum
   * competitive score of the sum, and hand it to the clause.
   */
  public void setMinCompetitiveScore(float minScoreSum) {
    for (int clause = 0; clause < numClauses; ++clause) {
      final float clauseMinScore =
          getMinCompetitiveScore(minScoreSum, sumOfOtherMaxScores[clause]);
      if (clauseMinScore <= 0) {
        // Clauses are ordered by decreasing max score, so every remaining
        // clause would end up with a minimum competitive score of 0 as well.
        return;
      }
      scorers[clause].setMinCompetitiveScore(clauseMinScore);
    }
  }

  /**
   * Compute the minimum score that a clause has to produce for a hit to be
   * competitive, given the max scores of the other clauses.
   */
  private float getMinCompetitiveScore(float minScoreSum, double sumOfOtherMaxScores) {
    assert numClauses > 0;
    if (minScoreSum <= sumOfOtherMaxScores) {
      return 0f;
    }

    // Look for 'candidate' such that 'candidate + sumOfOtherMaxScores <= minScoreSum'.
    // TODO: is there an efficient way to find the greatest value that meets this requirement?
    float candidate = (float) (minScoreSum - sumOfOtherMaxScores);
    // Important: step by the ulp of minScoreSum rather than the ulp of the
    // candidate so that the loop converges quickly.
    final float step = Math.ulp(minScoreSum);
    for (int iters = 1; sumUpperBound(candidate + sumOfOtherMaxScores) > minScoreSum; ++iters) {
      candidate -= step;
      // Expected to converge in at most two iterations:
      //  - one because of the rounding error of the subtraction
      //  - one because of the error introduced by sumUpperBound
      assert iters <= 2 : iters;
    }
    return Math.max(candidate, 0f);
  }

  /** Convert a sum of scores into a float that bounds it regardless of summation order. */
  private float sumUpperBound(double sum) {
    if (numClauses > 2) {
      // The error of a sum depends on the order in which values are added.
      // To be safe against that, return an upper bound of the values the sum
      // may take: if the max relative error is b, two sums of the same
      // values are always within 2*b of each other.
      // Conjunctions could skip this factor since their summation order is
      // predictable, but in practice it would not help much because the
      // delta it introduces is usually cancelled by the float cast.
      final double b = MathUtil.sumRelativeErrorBound(numClauses);
      final double inflation = 1.0 + 2 * b;
      return (float) (inflation * sum);
    }
    // With at most two clauses the sum is the same whatever the order.
    return (float) sum;
  }
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
/** A Scorer for queries with a required part and an optional part.
@ -35,6 +36,7 @@ class ReqOptSumScorer extends Scorer {
private boolean optIsRequired;
private final DocIdSetIterator approximation;
private final TwoPhaseIterator twoPhase;
final MaxScoreSumPropagator maxScorePropagator;
/** Construct a <code>ReqOptScorer</code>.
* @param reqScorer The required scorer. This must match.
@ -51,6 +53,7 @@ class ReqOptSumScorer extends Scorer {
this.optScorer = optScorer;
this.reqMaxScore = reqScorer.maxScore();
this.maxScorePropagator = new MaxScoreSumPropagator(Arrays.asList(reqScorer, optScorer));
final TwoPhaseIterator reqTwoPhase = reqScorer.twoPhaseIterator();
this.optTwoPhase = optScorer.twoPhaseIterator();
@ -208,14 +211,17 @@ class ReqOptSumScorer extends Scorer {
@Override
public float maxScore() {
return reqScorer.maxScore() + optScorer.maxScore();
return maxScorePropagator.maxScore();
}
@Override
public void setMinCompetitiveScore(float minScore) {
// Potentially move to a conjunction
if (optIsRequired == false && minScore > reqMaxScore) {
optIsRequired = true;
}
// And also propagate to sub clauses.
maxScorePropagator.setMinCompetitiveScore(minScore);
}
@Override

View File

@ -26,8 +26,6 @@ import java.util.Collection;
import java.util.List;
import java.util.OptionalInt;
import org.apache.lucene.util.MathUtil;
/**
* This implements the WAND (Weak AND) algorithm for dynamic pruning
* described in "Efficient Query Evaluation using a Two-Level Retrieval
@ -122,7 +120,7 @@ final class WANDScorer extends Scorer {
int tailSize;
final long cost;
final float maxScore;
final MaxScoreSumPropagator maxScorePropagator;
WANDScorer(Weight weight, Collection<Scorer> scorers) {
super(weight);
@ -145,12 +143,10 @@ final class WANDScorer extends Scorer {
// Use a scaling factor of 0 if all max scores are either 0 or +Infty
this.scalingFactor = scalingFactor.orElse(0);
double maxScoreSum = 0;
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
float maxScore = scorer.maxScore();
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
maxScoreSum += maxScore;
addLead(w);
}
@ -159,12 +155,7 @@ final class WANDScorer extends Scorer {
cost += w.cost;
}
this.cost = cost;
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size());
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum);
this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
}
// returns a boolean so that it can be called from assert
@ -195,10 +186,15 @@ final class WANDScorer extends Scorer {
@Override
public void setMinCompetitiveScore(float minScore) {
// Let this disjunction know about the new min score so that it can skip
// over clauses that produce low scores.
assert minScore >= 0;
long scaledMinScore = scaleMinScore(minScore, scalingFactor);
assert scaledMinScore >= minCompetitiveScore;
minCompetitiveScore = scaledMinScore;
// And also propagate to sub clauses.
maxScorePropagator.setMinCompetitiveScore(minScore);
}
@Override
@ -386,7 +382,7 @@ final class WANDScorer extends Scorer {
@Override
public float maxScore() {
return maxScore;
return maxScorePropagator.maxScore();
}
@Override

View File

@ -0,0 +1,198 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
/**
 * Tests for {@link MaxScoreSumPropagator}: the reported max score of the sum
 * and the minimum competitive scores that get pushed down to each clause.
 */
public class TestMaxScoreSumPropagator extends LuceneTestCase {

  /**
   * Minimal {@link Scorer} that reports a fixed max score and records the last
   * minimum competitive score it was given. Iteration and scoring are not
   * supported.
   */
  private static class FakeScorer extends Scorer {

    final float maxScore;
    float minCompetitiveScore;

    FakeScorer(float maxScore) {
      super(null);
      this.maxScore = maxScore;
    }

    @Override
    public int docID() {
      throw new UnsupportedOperationException();
    }

    @Override
    public float score() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public DocIdSetIterator iterator() {
      throw new UnsupportedOperationException();
    }

    @Override
    public float maxScore() {
      return maxScore;
    }

    @Override
    public void setMinCompetitiveScore(float minCompetitiveScore) {
      this.minCompetitiveScore = minCompetitiveScore;
    }
  }

  /** With no clause the max score is 0 and propagation must not throw. */
  public void test0Clause() {
    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Collections.emptyList());
    assertEquals(0f, p.maxScore(), 0f);
    p.setMinCompetitiveScore(0f); // no exception
    p.setMinCompetitiveScore(0.5f); // no exception
  }

  /** With a single clause, the minimum competitive score is forwarded unchanged. */
  public void test1Clause() {
    FakeScorer a = new FakeScorer(1);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Collections.singletonList(a));
    assertEquals(1f, p.maxScore(), 0f);
    p.setMinCompetitiveScore(0f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    p.setMinCompetitiveScore(0.5f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    p.setMinCompetitiveScore(1f);
    assertEquals(1f, a.minCompetitiveScore, 0f);
  }

  /** Two clauses with max scores 1 and 2, checked against hand-computed expectations. */
  public void test2Clauses() {
    FakeScorer a = new FakeScorer(1);
    FakeScorer b = new FakeScorer(2);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b));
    assertEquals(3f, p.maxScore(), 0f);

    p.setMinCompetitiveScore(1f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(1f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2.5f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    assertEquals(1.5f, b.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(3f);
    assertEquals(1f, a.minCompetitiveScore, 0f);
    assertEquals(2f, b.minCompetitiveScore, 0f);
  }

  /** Three clauses with max scores 1, 2 and 1.5, checked against hand-computed expectations. */
  public void test3Clauses() {
    FakeScorer a = new FakeScorer(1);
    FakeScorer b = new FakeScorer(2);
    FakeScorer c = new FakeScorer(1.5f);

    MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b, c));
    assertEquals(4.5f, p.maxScore(), 0f);

    p.setMinCompetitiveScore(1f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(2f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(3f);
    assertEquals(0f, a.minCompetitiveScore, 0f);
    assertEquals(0.5f, b.minCompetitiveScore, 0f);
    assertEquals(0f, c.minCompetitiveScore, 0f);

    p.setMinCompetitiveScore(4f);
    assertEquals(0.5f, a.minCompetitiveScore, 0f);
    assertEquals(1.5f, b.minCompetitiveScore, 0f);
    assertEquals(1f, c.minCompetitiveScore, 0f);
  }

  /**
   * Two clauses with random max scores, where {@code b} always has the greater
   * max score, exercised with increasing minimum competitive scores.
   */
  public void test2ClausesRandomScore() {
    for (int iter = 0; iter < 10; ++iter) {
      FakeScorer a = new FakeScorer(random().nextFloat());
      FakeScorer b = new FakeScorer(Math.nextUp(a.maxScore()) + random().nextFloat());

      MaxScoreSumPropagator p = new MaxScoreSumPropagator(Arrays.asList(a, b));
      assertEquals(a.maxScore() + b.maxScore(), p.maxScore(), 0f);
      assertMinCompetitiveScore(Arrays.asList(a, b), p, Math.nextUp(a.maxScore()));
      assertMinCompetitiveScore(Arrays.asList(a, b), p, (a.maxScore() + b.maxScore()) / 2);
      assertMinCompetitiveScore(Arrays.asList(a, b), p, Math.nextDown(a.maxScore() + b.maxScore()));
      assertMinCompetitiveScore(Arrays.asList(a, b), p, a.maxScore() + b.maxScore());
    }
  }

  /** A random number of clauses (at least 3) with random max scores and random thresholds. */
  public void testNClausesRandomScore() {
    for (int iter = 0; iter < 100; ++iter) {
      List<FakeScorer> scorers = new ArrayList<>();
      int numScorers = TestUtil.nextInt(random(), 3, 4 << random().nextInt(8));
      double sumOfMaxScore = 0;
      for (int i = 0; i < numScorers; ++i) {
        float maxScore = random().nextFloat();
        scorers.add(new FakeScorer(maxScore));
        sumOfMaxScore += maxScore;
      }

      MaxScoreSumPropagator p = new MaxScoreSumPropagator(scorers);
      // The propagator reports an upper bound, so it may exceed the plain sum.
      assertTrue(p.maxScore() >= (float) sumOfMaxScore);
      for (int i = 0; i < 10; ++i) {
        final float minCompetitiveScore = random().nextFloat() * numScorers;
        assertMinCompetitiveScore(scorers, p, minCompetitiveScore);
        // reset
        for (FakeScorer scorer : scorers) {
          scorer.minCompetitiveScore = 0;
        }
      }
    }
  }

  /**
   * Propagate {@code minCompetitiveScore} through {@code p} and check, for every
   * scorer that holds a non-zero minimum competitive score, that this score plus
   * the max scores of all other scorers does not exceed {@code minCompetitiveScore}.
   *
   * @param scorers the scorers that {@code p} was built from
   * @param p the propagator under test
   * @param minCompetitiveScore the minimum competitive score to set on the sum
   */
  private void assertMinCompetitiveScore(Collection<FakeScorer> scorers, MaxScoreSumPropagator p, float minCompetitiveScore) {
    p.setMinCompetitiveScore(minCompetitiveScore);

    for (FakeScorer scorer : scorers) {
      if (scorer.minCompetitiveScore == 0f) {
        // No propagation was performed for this scorer, it still visits all
        // hits. 'scorers' is in the caller's order rather than sorted by max
        // score, so later scorers may still hold a non-zero value: skip this
        // one only instead of ending the whole check.
        continue;
      }
      double scoreSum = scorer.minCompetitiveScore;
      for (FakeScorer scorer2 : scorers) {
        if (scorer2 != scorer) {
          scoreSum += scorer2.maxScore();
        }
      }
      assertTrue(
          "scoreSum=" + scoreSum + ", minCompetitiveScore=" + minCompetitiveScore,
          (float) scoreSum <= minCompetitiveScore);
    }
  }
}