Simplify max score for kNN vector queries (#12146)

The helper class DocAndScoreQuery implements advanceShallow to help skip
non-competitive documents. This method doesn't actually keep track of where it
has advanced, which means it can do extra work.

Overall the complexity here didn't seem worth it, given the low cost of
collecting matching kNN docs. This PR switches to a simpler approach, which uses
a fixed upper bound on the max score.
This commit is contained in:
Julie Tibshirani 2023-02-16 12:03:59 -08:00 committed by GitHub
parent 8e15c665be
commit 8340b01c3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 79 deletions

View File

@ -189,6 +189,10 @@ abstract class AbstractKnnVectorQuery extends Query {
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
assert len > 0;
float maxScore = topK.scoreDocs[0].score;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
@ -197,7 +201,7 @@ abstract class AbstractKnnVectorQuery extends Query {
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
}
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
@ -265,6 +269,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final int[] docs;
private final float[] scores;
private final float maxScore;
private final int[] segmentStarts;
private final Object contextIdentity;
@ -280,9 +285,11 @@ abstract class AbstractKnnVectorQuery extends Query {
* @param contextIdentity an object identifying the reader context that was used to build this
* query
*/
DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
DocAndScoreQuery(
int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
this.docs = docs;
this.scores = scores;
this.maxScore = maxScore;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
}
@ -343,11 +350,6 @@ abstract class AbstractKnnVectorQuery extends Query {
/**
 * Returns an upper bound on the score for entries up to the given doc id, scaled by {@code
 * boost}.
 *
 * <p>As written here, the bound is the largest value of {@code scores[idx]} over the indices
 * visited by the loop below; it is 0 (times {@code boost}) when the loop visits no index.
 *
 * @param docId the upper doc id; {@code context.docBase} is added to it before it is compared
 *     with the entries of {@code docs}
 * @return the maximum visited score multiplied by {@code boost}
 */
@Override
public float getMaxScore(int docId) {
// Shift the target by the segment's docBase so it is comparable with the entries of docs.
docId += context.docBase;
// Local accumulator for the running maximum; remains 0 if no entry qualifies.
float maxScore = 0;
// Scan from max(0, upTo) and stop at upper or at the first entry greater than the target.
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore * boost;
}
@ -356,19 +358,6 @@ abstract class AbstractKnnVectorQuery extends Query {
return scores[upTo] * boost;
}
/**
 * Looks up the first entry of {@code docs} that is at or after the given target.
 *
 * <p>The search range is {@code docs[max(upTo, lower) .. upper)} and the search key is {@code
 * docid + context.docBase}. This method does not assign {@code upTo} or any other field, so no
 * position is remembered between calls.
 *
 * @param docid the target doc id; {@code context.docBase} is added to it to form the search key
 * @return the entry of {@code docs} equal to the key, or the next greater entry within the
 *     range, or {@code NO_MORE_DOCS} when no such entry lies below {@code upper}
 */
@Override
public int advanceShallow(int docid) {
// Begin at whichever is larger: the current position (upTo) or the lower bound.
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
// Key not present: binarySearch returned (-(insertion point) - 1); recover the insertion point.
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
// The insertion point is past the end of the searched range.
return NO_MORE_DOCS;
}
// NOTE(review): the search key had context.docBase added, but the value returned here is the
// raw docs entry with nothing subtracted — confirm whether callers expect that.
return docs[docidIndex];
}
/**
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous

View File

@ -244,36 +244,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
}
/**
 * Checks the values returned by {@code Scorer.advanceShallow} on a rewritten kNN vector query,
 * both before and after the scorer's iterator has been advanced.
 *
 * <p>The index holds five documents whose vectors are {@code (j, j)} for {@code j} in 0..4, and
 * the query is built from the vector {@code (2, 3)} with a final argument of 3 (presumably the
 * number of hits requested — confirm against {@code getKnnVectorQuery}).
 */
public void testAdvanceShallow() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
// Index five documents; document j carries the vector (j, j).
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
// Rewrite the query, then obtain a scorer for the first leaf of the reader.
Query dasq = query.rewrite(searcher);
Scorer scorer =
dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
// before advancing the iterator
// Targets 0 and 1 are both expected to yield 1; a target of 10 is expected to yield
// NO_MORE_DOCS.
assertEquals(1, scorer.advanceShallow(0));
assertEquals(1, scorer.advanceShallow(1));
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
// after advancing the iterator
scorer.iterator().advance(2);
// A target of 0 is now expected to yield 2 rather than 1.
assertEquals(2, scorer.advanceShallow(0));
assertEquals(2, scorer.advanceShallow(2));
assertEquals(3, scorer.advanceShallow(3));
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
}
}
}
public void testScoreEuclidean() throws IOException {
float[][] vectors = new float[5][];
for (int j = 0; j < 5; j++) {
@ -291,9 +261,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
@ -304,6 +271,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0);
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
@ -330,32 +298,30 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) /2.
*/
float maxAtZero =
float score0 =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) /2
*/
float expected =
float score1 =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
// doc 1 happens to have the maximum score
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(score0, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
assertEquals(score1, scorer.score(), 0.0001);
// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

View File

@ -133,32 +133,30 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) /2.
*/
float maxAtZero =
float score0 =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) /2
*/
float expected =
float score1 =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
// doc 1 happens to have the max score
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(score0, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
assertEquals(score1, scorer.score(), 0.0001);
// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);