Enable rank-unsafe optimization of top-k hit computations by quantizing scores.

This adds a `ScoreQuantizingCollector`, which quantizes scores with a
configurable number of accuracy bits. This allows dynamic pruning to skip
hits that would have similar scores more efficiently. This should be
considered rank-unsafe, since the top hits may differ from those returned by
running the top-score collector on its own; however, it is worth noting that
the top hits are correct in quantized space.
This commit is contained in:
Adrien Grand 2023-10-06 18:36:19 +02:00
parent 2474940bff
commit 398c7e1f5c
2 changed files with 159 additions and 0 deletions

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.misc.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.FilterScorable;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.TopScoreDocCollector;
/**
 * A {@link FilterCollector} that hands a reduced-precision view of scores to the wrapped collector
 * so that top-k hits can be computed more efficiently. Wrapping a collector this way should
 * generally be treated as unsafe: the top hits may differ from those the wrapped collector would
 * return on its own. The top hits are, however, correct in quantized space.
 */
public final class ScoreQuantizingCollector extends FilterCollector {

  /** Bit mask applied to the raw float bits; its zero bits are the mantissa bits that get dropped. */
  private final int mask;

  /**
   * Sole constructor. The number of accuracy bits configures a trade-off between performance and
   * retained accuracy: lower values retain less accuracy but make performance better. It is
   * recommended to avoid passing values greater than 5 to actually observe speedups.
   *
   * @param in The collector to wrap, most likely a {@link TopScoreDocCollector}.
   * @param accuracyBits How many bits of accuracy to retain, in [1,24). 24 is disallowed since this
   *     is the number of accuracy bits of single-precision floating point numbers, so this wrapper
   *     would not change scores.
   * @throws IllegalArgumentException if {@code accuracyBits} is outside of [1,24)
   */
  public ScoreQuantizingCollector(Collector in, int accuracyBits) {
    super(in);
    final boolean inRange = accuracyBits >= 1 && accuracyBits < 24;
    if (inRange == false) {
      throw new IllegalArgumentException("accuracyBits must be in [1,24), got " + accuracyBits);
    }
    // A float stores 23 explicit mantissa bits plus one implicit leading bit, i.e. 24 bits of
    // accuracy in total. Retaining accuracyBits of them means dropping the (24 - accuracyBits)
    // lowest explicit mantissa bits.
    final int droppedMantissaBits = 24 - accuracyBits;
    mask = -1 << droppedMantissaBits;
  }

  /**
   * Wraps the delegate's leaf collector so that the {@link Scorable} it receives exposes quantized
   * scores.
   */
  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    final LeafCollector delegate = in.getLeafCollector(context);
    return new FilterLeafCollector(delegate) {
      @Override
      public void setScorer(Scorable scorer) throws IOException {
        final Scorable quantized = new QuantizingScorable(scorer);
        in.setScorer(quantized);
      }
    };
  }

  /**
   * {@link Scorable} wrapper that passes scores through {@link #roundDown(float)} and minimum
   * competitive scores through {@link #roundUp(float)} before forwarding them.
   */
  private class QuantizingScorable extends FilterScorable {

    QuantizingScorable(Scorable scorer) {
      super(scorer);
    }

    @Override
    public float score() throws IOException {
      final float exact = in.score();
      return roundDown(exact);
    }

    @Override
    public void setMinCompetitiveScore(float minScore) throws IOException {
      final float quantizedMin = roundUp(minScore);
      in.setMinCompetitiveScore(quantizedMin);
    }
  }

  /**
   * Clears the dropped low-order mantissa bits of {@code score}. For non-negative finite inputs
   * this is the largest quantized value that is not greater than {@code score}. Non-finite inputs
   * (NaN and infinities) are returned unchanged.
   */
  float roundDown(float score) {
    if (Float.isFinite(score)) {
      final int truncated = Float.floatToIntBits(score) & mask;
      return Float.intBitsToFloat(truncated);
    }
    return score;
  }

  /**
   * Moves {@code score} to the next bit pattern whose dropped low-order mantissa bits are all
   * zero, leaving already-quantized values untouched. For non-negative finite inputs this is the
   * smallest quantized value that is not less than {@code score}, which may be positive infinity.
   * Non-finite inputs (NaN and infinities) are returned unchanged.
   */
  float roundUp(float score) {
    if (Float.isFinite(score)) {
      final int droppedBits = ~mask;
      final int raw = Float.floatToIntBits(score);
      // Subtracting one first keeps exact multiples in place: their dropped bits all become ones,
      // and the final increment carries back to the original pattern.
      final int bumped = ((raw - 1) | droppedBits) + 1;
      return Float.intBitsToFloat(bumped);
    }
    return score;
  }
}

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.misc.search;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
/** Tests argument validation and the rounding helpers of {@link ScoreQuantizingCollector}. */
public class TestScoreQuantizingCollector extends LuceneTestCase {

  /** Values of accuracyBits outside of [1,24) must be rejected, all values inside accepted. */
  public void testValidation() {
    Collector collector = TopScoreDocCollector.create(10, 10);
    expectThrows(IllegalArgumentException.class, () -> new ScoreQuantizingCollector(collector, 0));
    expectThrows(IllegalArgumentException.class, () -> new ScoreQuantizingCollector(collector, 24));
    for (int i = 1; i <= 23; ++i) {
      new ScoreQuantizingCollector(collector, i); // no exception
    }
  }

  /** Checks roundDown/roundUp with 4 accuracy bits, i.e. 3 explicit mantissa bits retained. */
  public void testQuantize() {
    ScoreQuantizingCollector collector =
        new ScoreQuantizingCollector(TopScoreDocCollector.create(10, 10), 4);

    // A value that falls between two quantization steps
    assertEquals(1.125f, collector.roundDown(1.2345f), 0f);
    assertEquals(1.25f, collector.roundUp(1.2345f), 0f);

    // A value that is already quantized is left untouched in both directions
    assertEquals(1.25f, collector.roundDown(1.25f), 0f);
    assertEquals(1.25f, collector.roundUp(1.25f), 0f);

    // Zero
    assertEquals(0f, collector.roundDown(0f), 0f);
    assertEquals(0f, collector.roundUp(0f), 0f);

    // Smallest subnormal: rounds down to zero and up to the first quantization step. With 4
    // accuracy bits the 20 lowest mantissa bits are dropped, so that step has bit pattern 1 << 20.
    assertEquals(0f, collector.roundDown(Float.MIN_VALUE), 0f);
    assertEquals(Float.intBitsToFloat(1 << 20), collector.roundUp(Float.MIN_VALUE), 0f);

    // Smallest normal value is already quantized
    assertEquals(Float.MIN_NORMAL, collector.roundDown(Float.MIN_NORMAL), 0f);
    assertEquals(Float.MIN_NORMAL, collector.roundUp(Float.MIN_NORMAL), 0f);

    // Largest finite value: rounding up overflows to positive infinity
    assertEquals(3.1901472E38f, collector.roundDown(Float.MAX_VALUE), 0f);
    assertEquals(Float.POSITIVE_INFINITY, collector.roundUp(Float.MAX_VALUE), 0f);
  }
}