mirror of https://github.com/apache/lucene.git
Simplify max score for kNN vector queries (#12146)
The helper class DocAndScoreQuery implements advanceShallow to help skip non-competitive documents. However, this method does not keep track of where it has already advanced, so repeated calls can redo work. Overall, the added complexity didn't seem worth it, given the low cost of collecting the matching kNN docs. This PR switches to a simpler approach: advanceShallow is no longer overridden, and getMaxScore returns a fixed upper bound on the max score (the top hit's score, captured before the hits are re-sorted by doc ID, multiplied by the boost).
This commit is contained in:
parent
8e15c665be
commit
8340b01c3c
|
@ -189,6 +189,10 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
|
|
||||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||||
int len = topK.scoreDocs.length;
|
int len = topK.scoreDocs.length;
|
||||||
|
|
||||||
|
assert len > 0;
|
||||||
|
float maxScore = topK.scoreDocs[0].score;
|
||||||
|
|
||||||
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
|
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
|
||||||
int[] docs = new int[len];
|
int[] docs = new int[len];
|
||||||
float[] scores = new float[len];
|
float[] scores = new float[len];
|
||||||
|
@ -197,7 +201,7 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
scores[i] = topK.scoreDocs[i].score;
|
scores[i] = topK.scoreDocs[i].score;
|
||||||
}
|
}
|
||||||
int[] segmentStarts = findSegmentStarts(reader, docs);
|
int[] segmentStarts = findSegmentStarts(reader, docs);
|
||||||
return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
|
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
|
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
|
||||||
|
@ -265,6 +269,7 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
|
|
||||||
private final int[] docs;
|
private final int[] docs;
|
||||||
private final float[] scores;
|
private final float[] scores;
|
||||||
|
private final float maxScore;
|
||||||
private final int[] segmentStarts;
|
private final int[] segmentStarts;
|
||||||
private final Object contextIdentity;
|
private final Object contextIdentity;
|
||||||
|
|
||||||
|
@ -280,9 +285,11 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
* @param contextIdentity an object identifying the reader context that was used to build this
|
* @param contextIdentity an object identifying the reader context that was used to build this
|
||||||
* query
|
* query
|
||||||
*/
|
*/
|
||||||
DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
|
DocAndScoreQuery(
|
||||||
|
int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
|
||||||
this.docs = docs;
|
this.docs = docs;
|
||||||
this.scores = scores;
|
this.scores = scores;
|
||||||
|
this.maxScore = maxScore;
|
||||||
this.segmentStarts = segmentStarts;
|
this.segmentStarts = segmentStarts;
|
||||||
this.contextIdentity = contextIdentity;
|
this.contextIdentity = contextIdentity;
|
||||||
}
|
}
|
||||||
|
@ -343,11 +350,6 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float getMaxScore(int docId) {
|
public float getMaxScore(int docId) {
|
||||||
docId += context.docBase;
|
|
||||||
float maxScore = 0;
|
|
||||||
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
|
|
||||||
maxScore = Math.max(maxScore, scores[idx]);
|
|
||||||
}
|
|
||||||
return maxScore * boost;
|
return maxScore * boost;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,19 +358,6 @@ abstract class AbstractKnnVectorQuery extends Query {
|
||||||
return scores[upTo] * boost;
|
return scores[upTo] * boost;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int advanceShallow(int docid) {
|
|
||||||
int start = Math.max(upTo, lower);
|
|
||||||
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
|
|
||||||
if (docidIndex < 0) {
|
|
||||||
docidIndex = -1 - docidIndex;
|
|
||||||
}
|
|
||||||
if (docidIndex >= upper) {
|
|
||||||
return NO_MORE_DOCS;
|
|
||||||
}
|
|
||||||
return docs[docidIndex];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* move the implementation of docID() into a differently-named method so we can call it
|
* move the implementation of docID() into a differently-named method so we can call it
|
||||||
* from DocIDSetIterator.docID() even though this class is anonymous
|
* from DocIDSetIterator.docID() even though this class is anonymous
|
||||||
|
|
|
@ -244,36 +244,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testAdvanceShallow() throws IOException {
|
|
||||||
try (Directory d = newDirectory()) {
|
|
||||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
|
||||||
for (int j = 0; j < 5; j++) {
|
|
||||||
Document doc = new Document();
|
|
||||||
doc.add(getKnnVectorField("field", new float[] {j, j}));
|
|
||||||
w.addDocument(doc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
try (IndexReader reader = DirectoryReader.open(d)) {
|
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
|
||||||
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
|
|
||||||
Query dasq = query.rewrite(searcher);
|
|
||||||
Scorer scorer =
|
|
||||||
dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
|
|
||||||
// before advancing the iterator
|
|
||||||
assertEquals(1, scorer.advanceShallow(0));
|
|
||||||
assertEquals(1, scorer.advanceShallow(1));
|
|
||||||
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
|
|
||||||
|
|
||||||
// after advancing the iterator
|
|
||||||
scorer.iterator().advance(2);
|
|
||||||
assertEquals(2, scorer.advanceShallow(0));
|
|
||||||
assertEquals(2, scorer.advanceShallow(2));
|
|
||||||
assertEquals(3, scorer.advanceShallow(3));
|
|
||||||
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testScoreEuclidean() throws IOException {
|
public void testScoreEuclidean() throws IOException {
|
||||||
float[][] vectors = new float[5][];
|
float[][] vectors = new float[5][];
|
||||||
for (int j = 0; j < 5; j++) {
|
for (int j = 0; j < 5; j++) {
|
||||||
|
@ -291,9 +261,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
||||||
assertEquals(-1, scorer.docID());
|
assertEquals(-1, scorer.docID());
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
||||||
// test getMaxScore
|
|
||||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
|
||||||
assertEquals(0, scorer.getMaxScore(0), 0);
|
|
||||||
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
|
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
|
||||||
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
|
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
|
||||||
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
||||||
|
@ -304,6 +271,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
||||||
assertEquals(1 / 6f, scorer.score(), 0);
|
assertEquals(1 / 6f, scorer.score(), 0);
|
||||||
assertEquals(3, it.advance(3));
|
assertEquals(3, it.advance(3));
|
||||||
assertEquals(1 / 2f, scorer.score(), 0);
|
assertEquals(1 / 2f, scorer.score(), 0);
|
||||||
|
|
||||||
assertEquals(NO_MORE_DOCS, it.advance(4));
|
assertEquals(NO_MORE_DOCS, it.advance(4));
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
}
|
}
|
||||||
|
@ -330,32 +298,30 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
||||||
assertEquals(-1, scorer.docID());
|
assertEquals(-1, scorer.docID());
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
||||||
// test getMaxScore
|
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
||||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
|
||||||
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
|
||||||
* normalized by (1 + x) /2.
|
* normalized by (1 + x) /2.
|
||||||
*/
|
*/
|
||||||
float maxAtZero =
|
float score0 =
|
||||||
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
||||||
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
|
|
||||||
|
|
||||||
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
|
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
|
||||||
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
|
|
||||||
* normalized by (1 + x) /2
|
* normalized by (1 + x) /2
|
||||||
*/
|
*/
|
||||||
float expected =
|
float score1 =
|
||||||
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
|
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
|
||||||
assertEquals(expected, scorer.getMaxScore(2), 0);
|
|
||||||
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
// doc 1 happens to have the maximum score
|
||||||
|
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
|
||||||
|
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
|
||||||
|
|
||||||
DocIdSetIterator it = scorer.iterator();
|
DocIdSetIterator it = scorer.iterator();
|
||||||
assertEquals(3, it.cost());
|
assertEquals(3, it.cost());
|
||||||
assertEquals(0, it.nextDoc());
|
assertEquals(0, it.nextDoc());
|
||||||
// doc 0 has (1, 1)
|
// doc 0 has (1, 1)
|
||||||
assertEquals(maxAtZero, scorer.score(), 0.0001);
|
assertEquals(score0, scorer.score(), 0.0001);
|
||||||
assertEquals(1, it.advance(1));
|
assertEquals(1, it.advance(1));
|
||||||
assertEquals(expected, scorer.score(), 0);
|
assertEquals(score1, scorer.score(), 0.0001);
|
||||||
assertEquals(2, it.nextDoc());
|
|
||||||
// since topK was 3
|
// since topK was 3
|
||||||
assertEquals(NO_MORE_DOCS, it.advance(4));
|
assertEquals(NO_MORE_DOCS, it.advance(4));
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
|
@ -133,32 +133,30 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
|
||||||
assertEquals(-1, scorer.docID());
|
assertEquals(-1, scorer.docID());
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
||||||
// test getMaxScore
|
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
||||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
|
||||||
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
|
||||||
* normalized by (1 + x) /2.
|
* normalized by (1 + x) /2.
|
||||||
*/
|
*/
|
||||||
float maxAtZero =
|
float score0 =
|
||||||
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
||||||
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
|
|
||||||
|
|
||||||
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
|
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
|
||||||
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
|
|
||||||
* normalized by (1 + x) /2
|
* normalized by (1 + x) /2
|
||||||
*/
|
*/
|
||||||
float expected =
|
float score1 =
|
||||||
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
|
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
|
||||||
assertEquals(expected, scorer.getMaxScore(2), 0);
|
|
||||||
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
// doc 1 happens to have the max score
|
||||||
|
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
|
||||||
|
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
|
||||||
|
|
||||||
DocIdSetIterator it = scorer.iterator();
|
DocIdSetIterator it = scorer.iterator();
|
||||||
assertEquals(3, it.cost());
|
assertEquals(3, it.cost());
|
||||||
assertEquals(0, it.nextDoc());
|
assertEquals(0, it.nextDoc());
|
||||||
// doc 0 has (1, 1)
|
// doc 0 has (1, 1)
|
||||||
assertEquals(maxAtZero, scorer.score(), 0.0001);
|
assertEquals(score0, scorer.score(), 0.0001);
|
||||||
assertEquals(1, it.advance(1));
|
assertEquals(1, it.advance(1));
|
||||||
assertEquals(expected, scorer.score(), 0);
|
assertEquals(score1, scorer.score(), 0.0001);
|
||||||
assertEquals(2, it.nextDoc());
|
|
||||||
// since topK was 3
|
// since topK was 3
|
||||||
assertEquals(NO_MORE_DOCS, it.advance(4));
|
assertEquals(NO_MORE_DOCS, it.advance(4));
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
Loading…
Reference in New Issue