LUCENE-9071: Speed up BM25 scores. (#1043)

This commit is contained in:
Adrien Grand 2019-12-09 18:59:18 +01:00 committed by GitHub
parent df933f8104
commit c413656b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 20 additions and 8 deletions

View File

@@ -187,7 +187,7 @@ public class BM25Similarity extends Similarity {
float[] cache = new float[256]; float[] cache = new float[256];
for (int i = 0; i < cache.length; i++) { for (int i = 0; i < cache.length; i++) {
cache[i] = k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl); cache[i] = 1f / (k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl));
} }
return new BM25Scorer(boost, k1, b, idf, avgdl, cache); return new BM25Scorer(boost, k1, b, idf, avgdl, cache);
} }
@@ -221,8 +221,17 @@ public class BM25Similarity extends Similarity {
@Override @Override
public float score(float freq, long encodedNorm) { public float score(float freq, long encodedNorm) {
double norm = cache[((byte) encodedNorm) & 0xFF]; // In order to guarantee monotonicity with both freq and norm without
return weight * (float) (freq / (freq + norm)); // promoting to doubles, we rewrite freq / (freq + norm) to
// 1 - 1 / (1 + freq * 1/norm).
// freq * 1/norm is guaranteed to be monotonic for both freq and norm due
// to the fact that multiplication and division round to the nearest
// float. And then monotonicity is preserved through composition via
// x -> 1 + x and x -> 1 - 1/x.
// Finally we expand weight * (1 - 1 / (1 + freq * 1/norm)) to
// weight - weight / (1 + freq * 1/norm), which runs slightly faster.
float normInverse = cache[((byte) encodedNorm) & 0xFF];
return weight - weight / (1f + freq * normInverse);
} }
@Override @Override
@@ -230,8 +239,11 @@ public class BM25Similarity extends Similarity {
List<Explanation> subs = new ArrayList<>(explainConstantFactors()); List<Explanation> subs = new ArrayList<>(explainConstantFactors());
Explanation tfExpl = explainTF(freq, encodedNorm); Explanation tfExpl = explainTF(freq, encodedNorm);
subs.add(tfExpl); subs.add(tfExpl);
return Explanation.match(weight * tfExpl.getValue().floatValue(), float normInverse = cache[((byte) encodedNorm) & 0xFF];
"score(freq="+freq.getValue()+"), product of:", subs); // not using "product of" since the rewrite that we do in score()
// introduces a small rounding error that CheckHits complains about
return Explanation.match(weight - weight / (1f + freq.getValue().floatValue() * normInverse),
"score(freq="+freq.getValue()+"), computed as boost * idf * tf from:", subs);
} }
private Explanation explainTF(Explanation freq, long norm) { private Explanation explainTF(Explanation freq, long norm) {
@@ -246,9 +258,9 @@ public class BM25Similarity extends Similarity {
subs.add(Explanation.match(doclen, "dl, length of field")); subs.add(Explanation.match(doclen, "dl, length of field"));
} }
subs.add(Explanation.match(avgdl, "avgdl, average length of field")); subs.add(Explanation.match(avgdl, "avgdl, average length of field"));
float normValue = k1 * ((1 - b) + b * doclen / avgdl); float normInverse = 1f / (k1 * ((1 - b) + b * doclen / avgdl));
return Explanation.match( return Explanation.match(
(float) (freq.getValue().floatValue() / (freq.getValue().floatValue() + (double) normValue)), 1f - 1f / (1 + freq.getValue().floatValue() * normInverse),
"tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", subs); "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", subs);
} }