From 8340b01c3cc229f33584ce2178b07b8984daa6a9 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Thu, 16 Feb 2023 12:03:59 -0800 Subject: [PATCH] Simplify max score for kNN vector queries (#12146) The helper class DocAndScoreQuery implements advanceShallow to help skip non-competitive documents. This method doesn't actually keep track of where it has advanced, which means it can do extra work. Overall the complexity here didn't seem worth it, given the low cost of collecting matching kNN docs. This PR switches to a simpler approach, which uses a fixed upper bound on the max score. --- .../lucene/search/AbstractKnnVectorQuery.java | 29 +++------- .../search/BaseKnnVectorQueryTestCase.java | 58 ++++--------------- .../search/TestKnnFloatVectorQuery.java | 24 ++++---- 3 files changed, 32 insertions(+), 79 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 6e348fcc5ee..f2e4b125f98 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -189,6 +189,10 @@ abstract class AbstractKnnVectorQuery extends Query { private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; + + assert len > 0; + float maxScore = topK.scoreDocs[0].score; + Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); int[] docs = new int[len]; float[] scores = new float[len]; @@ -197,7 +201,7 @@ abstract class AbstractKnnVectorQuery extends Query { scores[i] = topK.scoreDocs[i].score; } int[] segmentStarts = findSegmentStarts(reader, docs); - return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id()); + return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id()); } private int[] findSegmentStarts(IndexReader reader, int[] docs) { @@ -265,6 +269,7 @@ abstract class AbstractKnnVectorQuery extends Query { private final int[] docs; private final float[] scores; + private final float maxScore; private final int[] segmentStarts; private final Object contextIdentity; @@ -280,9 +285,11 @@ abstract class AbstractKnnVectorQuery extends Query { * @param contextIdentity an object identifying the reader context that was used to build this * query */ - DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + DocAndScoreQuery( + int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) { this.docs = docs; this.scores = scores; + this.maxScore = maxScore; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; } @@ -343,11 +350,6 @@ abstract class AbstractKnnVectorQuery extends Query { @Override public float getMaxScore(int docId) { - docId += context.docBase; - float maxScore = 0; - for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { - maxScore = Math.max(maxScore, scores[idx]); - } return maxScore * boost; } @@ -356,19 +358,6 @@ abstract class AbstractKnnVectorQuery extends Query { return scores[upTo] * boost; } - @Override - public int advanceShallow(int docid) { - int start = Math.max(upTo, lower); - int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); - if (docidIndex < 0) { - docidIndex = -1 - docidIndex; - } - if (docidIndex >= upper) { - return NO_MORE_DOCS; - } - return docs[docidIndex]; - } - /** * move the implementation of docID() into a differently-named method so we can call it * from DocIDSetIterator.docID() even though this class is anonymous diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 88d551619a4..dbb13d9f058 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -244,36 +244,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { } } - public void testAdvanceShallow() throws IOException { - try (Directory d = newDirectory()) { - try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { - for (int j = 0; j < 5; j++) { - Document doc = new Document(); - doc.add(getKnnVectorField("field", new float[] {j, j})); - w.addDocument(doc); - } - } - try (IndexReader reader = DirectoryReader.open(d)) { - IndexSearcher searcher = new IndexSearcher(reader); - AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3); - Query dasq = query.rewrite(searcher); - Scorer scorer = - dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0)); - // before advancing the iterator - assertEquals(1, scorer.advanceShallow(0)); - assertEquals(1, scorer.advanceShallow(1)); - assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10)); - - // after advancing the iterator - scorer.iterator().advance(2); - assertEquals(2, scorer.advanceShallow(0)); - assertEquals(2, scorer.advanceShallow(2)); - assertEquals(3, scorer.advanceShallow(3)); - assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10)); - } - } - } - public void testScoreEuclidean() throws IOException { float[][] vectors = new float[5][]; for (int j = 0; j < 5; j++) { @@ -291,9 +261,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { assertEquals(-1, scorer.docID()); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); - // test getMaxScore - assertEquals(0, scorer.getMaxScore(-1), 0); - assertEquals(0, scorer.getMaxScore(0), 0); // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5 assertEquals(1 / 2f, scorer.getMaxScore(2), 0); assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0); @@ -304,6 +271,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { assertEquals(1 / 6f, scorer.score(), 0); assertEquals(3, it.advance(3)); assertEquals(1 / 2f, scorer.score(), 0); + assertEquals(NO_MORE_DOCS, it.advance(4)); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); } @@ -330,32 +298,30 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { assertEquals(-1, scorer.docID()); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); - // test getMaxScore - assertEquals(0, scorer.getMaxScore(-1), 0); - /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then + /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then * normalized by (1 + x) /2. */ - float maxAtZero = + float score0 = (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2); - assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001); - /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4) - * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then + /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then * normalized by (1 + x) /2 */ - float expected = + float score1 = (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2); - assertEquals(expected, scorer.getMaxScore(2), 0); - assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0); + + // doc 1 happens to have the maximum score + assertEquals(score1, scorer.getMaxScore(2), 0.0001); + assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001); DocIdSetIterator it = scorer.iterator(); assertEquals(3, it.cost()); assertEquals(0, it.nextDoc()); // doc 0 has (1, 1) - assertEquals(maxAtZero, scorer.score(), 0.0001); + assertEquals(score0, scorer.score(), 0.0001); assertEquals(1, it.advance(1)); - assertEquals(expected, scorer.score(), 0); - assertEquals(2, it.nextDoc()); + assertEquals(score1, scorer.score(), 0.0001); + // since topK was 3 assertEquals(NO_MORE_DOCS, it.advance(4)); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index c4f10f874be..04f5a53b246 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -133,32 +133,30 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { assertEquals(-1, scorer.docID()); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); - // test getMaxScore - assertEquals(0, scorer.getMaxScore(-1), 0); - /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then + /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then * normalized by (1 + x) /2. */ - float maxAtZero = + float score0 = (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2); - assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001); - /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4) - * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then + /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then * normalized by (1 + x) /2 */ - float expected = + float score1 = (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2); - assertEquals(expected, scorer.getMaxScore(2), 0); - assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0); + + // doc 1 happens to have the max score + assertEquals(score1, scorer.getMaxScore(2), 0.0001); + assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001); DocIdSetIterator it = scorer.iterator(); assertEquals(3, it.cost()); assertEquals(0, it.nextDoc()); // doc 0 has (1, 1) - assertEquals(maxAtZero, scorer.score(), 0.0001); + assertEquals(score0, scorer.score(), 0.0001); assertEquals(1, it.advance(1)); - assertEquals(expected, scorer.score(), 0); - assertEquals(2, it.nextDoc()); + assertEquals(score1, scorer.score(), 0.0001); + // since topK was 3 assertEquals(NO_MORE_DOCS, it.advance(4)); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);