mirror of https://github.com/apache/lucene.git
Enable rank-unsafe optimization of top-k hit computations by quantizing scores.
This adds a `ScoreQuantizingCollector`, which quantizes scores with a configurable number of accuracy bits. This allows dynamic pruning to more efficiently skip hits that would have similar scores. While this should be considered rank-unsafe since top-hits are different compared to running the top-score collector on its own, it's worth noting that top hits are correct in quantized space.
This commit is contained in:
parent
2474940bff
commit
398c7e1f5c
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.misc.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Collector;
|
||||
import org.apache.lucene.search.FilterCollector;
|
||||
import org.apache.lucene.search.FilterLeafCollector;
|
||||
import org.apache.lucene.search.FilterScorable;
|
||||
import org.apache.lucene.search.LeafCollector;
|
||||
import org.apache.lucene.search.Scorable;
|
||||
import org.apache.lucene.search.TopScoreDocCollector;
|
||||
|
||||
/**
|
||||
* A {@link FilterCollector} that quantizes scores of the scorer in order to compute top-k hits more
|
||||
* efficiently. This should generally be considered an unsafe approach to computing top-k hits, as
|
||||
* the top hits would be different compared to without wrapping with this collector. However it is
|
||||
* worth noting that top hits would be correct in quantized space.
|
||||
*/
|
||||
public final class ScoreQuantizingCollector extends FilterCollector {
|
||||
|
||||
private final int mask;
|
||||
|
||||
/**
|
||||
* Sole constructor. The number of accuracy bits configures a trade-off between performance and
|
||||
* how much accuracy is retained. Lower values retain less accuracy but make performance better.
|
||||
* It is recommended to avoid passing values greater than 5 to actually observe speedups.
|
||||
*
|
||||
* @param in The collector to wrap, most likely a {@link TopScoreDocCollector}.
|
||||
* @param accuracyBits How many bits of accuracy to retain, in [1,24). 24 is disallowed since this
|
||||
* is the number of accuracy bits of single-precision floating point numbers, so this wrapper
|
||||
* would not change scores.
|
||||
*/
|
||||
public ScoreQuantizingCollector(Collector in, int accuracyBits) {
|
||||
super(in);
|
||||
if (accuracyBits < 1 || accuracyBits >= 24) {
|
||||
throw new IllegalArgumentException("accuracyBits must be in [1,24), got " + accuracyBits);
|
||||
}
|
||||
// floats have 23 mantissa bits
|
||||
// we do -1 on the number of accuracy bits to account for the implicit bit
|
||||
mask = ~0 << 23 - (accuracyBits - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
|
||||
return new FilterLeafCollector(in.getLeafCollector(context)) {
|
||||
@Override
|
||||
public void setScorer(Scorable scorer) throws IOException {
|
||||
in.setScorer(new QuantizingScorable(scorer));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private class QuantizingScorable extends FilterScorable {
|
||||
|
||||
QuantizingScorable(Scorable scorer) {
|
||||
super(scorer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return roundDown(in.score());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setMinCompetitiveScore(float minScore) throws IOException {
|
||||
in.setMinCompetitiveScore(roundUp(minScore));
|
||||
}
|
||||
}
|
||||
|
||||
float roundDown(float score) {
|
||||
if (Float.isFinite(score) == false) {
|
||||
return score;
|
||||
}
|
||||
int scoreBits = Float.floatToIntBits(score);
|
||||
scoreBits &= mask;
|
||||
return Float.intBitsToFloat(scoreBits);
|
||||
}
|
||||
|
||||
float roundUp(float score) {
|
||||
if (Float.isFinite(score) == false) {
|
||||
return score;
|
||||
}
|
||||
int scoreBits = Float.floatToIntBits(score);
|
||||
scoreBits = 1 + ((scoreBits - 1) | ~mask);
|
||||
return Float.intBitsToFloat(scoreBits);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.misc.search;
|
||||
|
||||
import org.apache.lucene.search.Collector;
|
||||
import org.apache.lucene.search.TopScoreDocCollector;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestScoreQuantizingCollector extends LuceneTestCase {
|
||||
|
||||
public void testValidation() {
|
||||
Collector collector = TopScoreDocCollector.create(10, 10);
|
||||
expectThrows(IllegalArgumentException.class, () -> new ScoreQuantizingCollector(collector, 0));
|
||||
expectThrows(IllegalArgumentException.class, () -> new ScoreQuantizingCollector(collector, 24));
|
||||
for (int i = 1; i <= 23; ++i) {
|
||||
new ScoreQuantizingCollector(collector, i); // no exception
|
||||
}
|
||||
}
|
||||
|
||||
public void testQuantize() {
|
||||
ScoreQuantizingCollector collector =
|
||||
new ScoreQuantizingCollector(TopScoreDocCollector.create(10, 10), 4);
|
||||
|
||||
assertEquals(1.125f, collector.roundDown(1.2345f), 0f);
|
||||
assertEquals(1.25f, collector.roundUp(1.2345f), 0f);
|
||||
|
||||
assertEquals(1.25f, collector.roundDown(1.25f), 0f);
|
||||
assertEquals(1.25f, collector.roundUp(1.25f), 0f);
|
||||
|
||||
assertEquals(0f, collector.roundDown(0f), 0f);
|
||||
assertEquals(0f, collector.roundUp(0f), 0f);
|
||||
|
||||
assertEquals(0f, collector.roundDown(Float.MIN_VALUE), 0f);
|
||||
assertEquals(0f, collector.roundUp(0f), 0f);
|
||||
|
||||
assertEquals(Float.MIN_NORMAL, collector.roundDown(Float.MIN_NORMAL), 0f);
|
||||
assertEquals(Float.MIN_NORMAL, collector.roundUp(Float.MIN_NORMAL), 0f);
|
||||
|
||||
assertEquals(3.1901472E38f, collector.roundDown(Float.MAX_VALUE), 0f);
|
||||
assertEquals(Float.POSITIVE_INFINITY, collector.roundUp(Float.MAX_VALUE), 0f);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue