From e0e5d81df802fba5570d40d44aec725c8dadfbaa Mon Sep 17 00:00:00 2001 From: Kaival Parikh <46070017+kaivalnp@users.noreply.github.com> Date: Tue, 6 Aug 2024 05:42:19 +0530 Subject: [PATCH] Add timeout support to AbstractVectorSimilarityQuery (#13285) Co-authored-by: Kaival Parikh --- lucene/CHANGES.txt | 3 + .../search/AbstractVectorSimilarityQuery.java | 120 ++++++++------ .../search/ByteVectorSimilarityQuery.java | 10 +- .../search/FloatVectorSimilarityQuery.java | 10 +- .../BaseVectorSimilarityQueryTestCase.java | 150 ++++++++++++++++++ 5 files changed, 237 insertions(+), 56 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 739887b0b91..265c339145e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -288,6 +288,9 @@ Improvements * GITHUB#13625: Remove BitSet#nextSetBit code duplication. (Greg Miller) +* GITHUB#13285: Early terminate graph searches of AbstractVectorSimilarityQuery to follow timeout set from + IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh) + Optimizations --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index 77a5ff6f24f..75d639c08fe 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Objects; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query { this.filter = filter; } + protected KnnCollectorManager getKnnCollectorManager() { + return (visitedLimit, context) -> + new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitedLimit); + } + abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException; protected abstract TopDocs approximateSearch( - LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException; + LeafReaderContext context, + Bits acceptDocs, + int visitLimit, + KnnCollectorManager knnCollectorManager) + throws IOException; @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) @@ -72,6 +83,10 @@ abstract class AbstractVectorSimilarityQuery extends Query { ? null : searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1); + final QueryTimeout queryTimeout = searcher.getTimeout(); + final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager = + new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout); + @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { if (filterWeight != null) { @@ -103,16 +118,14 @@ abstract class AbstractVectorSimilarityQuery extends Query { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader leafReader = context.reader(); Bits liveDocs = leafReader.getLiveDocs(); - final Scorer vectorSimilarityScorer; + // If there is no filter if (filterWeight == null) { // Return exhaustive results - TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE); - if (results.scoreDocs.length == 0) { - return null; - } - vectorSimilarityScorer = - VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); + TopDocs results = + approximateSearch( + context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); + return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs); } else { Scorer scorer = filterWeight.scorer(context); if (scorer == null) { @@ -143,27 +156,23 @@ abstract class AbstractVectorSimilarityQuery extends Query { } // Perform an approximate search - TopDocs results = approximateSearch(context, acceptDocs, cardinality); + TopDocs results = + approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager); - // If the limit was exhausted - if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) { - // Return a lazy-loading iterator - vectorSimilarityScorer = - VectorSimilarityScorer.fromAcceptDocs( - this, - boost, - createVectorScorer(context), - new BitSetIterator(acceptDocs, cardinality), - resultSimilarity); - } else if (results.scoreDocs.length == 0) { - return null; - } else { + if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO + // Return partial results only when timeout is met + || (queryTimeout != null && queryTimeout.shouldExit())) { // Return an iterator over the collected results - vectorSimilarityScorer = - VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); + return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs); + } else { + // Return a lazy-loading iterator + return VectorSimilarityScorerSupplier.fromAcceptDocs( + boost, + createVectorScorer(context), + new BitSetIterator(acceptDocs, cardinality), + resultSimilarity); } } - return new DefaultScorerSupplier(vectorSimilarityScorer); } @Override @@ -197,16 +206,20 @@ abstract class AbstractVectorSimilarityQuery extends Query { return Objects.hash(field, traversalSimilarity, resultSimilarity, filter); } - private static class VectorSimilarityScorer extends Scorer { + private static class VectorSimilarityScorerSupplier extends ScorerSupplier { final DocIdSetIterator iterator; final float[] cachedScore; - VectorSimilarityScorer(DocIdSetIterator iterator, float[] cachedScore) { + VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) { this.iterator = iterator; this.cachedScore = cachedScore; } - static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) { + static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) { + if (scoreDocs.length == 0) { + return null; + } + // Sort in ascending order of docid Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); @@ -252,18 +265,15 @@ abstract class AbstractVectorSimilarityQuery extends Query { } }; - return new VectorSimilarityScorer(iterator, cachedScore); + return new VectorSimilarityScorerSupplier(iterator, cachedScore); } - static VectorSimilarityScorer fromAcceptDocs( - Weight weight, - float boost, - VectorScorer scorer, - DocIdSetIterator acceptDocs, - float threshold) { + static VectorSimilarityScorerSupplier fromAcceptDocs( + float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) { if (scorer == null) { return null; } + float[] cachedScore = new float[1]; DocIdSetIterator vectorIterator = scorer.iterator(); DocIdSetIterator conjunction = @@ -281,27 +291,37 @@ abstract class AbstractVectorSimilarityQuery extends Query { } }; - return new VectorSimilarityScorer(iterator, cachedScore); + return new VectorSimilarityScorerSupplier(iterator, cachedScore); } @Override - public int docID() { - return iterator.docID(); + public Scorer get(long leadCost) { + return new Scorer() { + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public float getMaxScore(int upTo) { + return Float.POSITIVE_INFINITY; + } + + @Override + public float score() { + return cachedScore[0]; + } + }; } @Override - public DocIdSetIterator iterator() { - return iterator; - } - - @Override - public float getMaxScore(int upTo) { - return Float.POSITIVE_INFINITY; - } - - @Override - public float score() { - return cachedScore[0]; + public long cost() { + return iterator.cost(); } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index bd2190121ab..c547f1face7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -23,6 +23,7 @@ import java.util.Objects; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.Bits; /** @@ -106,10 +107,13 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override @SuppressWarnings("resource") - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit) + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - KnnCollector collector = - new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); + KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context); context.reader().searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index 3dc92482a77..4c7078ac140 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -23,6 +23,7 @@ import java.util.Objects; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; @@ -108,10 +109,13 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override @SuppressWarnings("resource") - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit) + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - KnnCollector collector = - new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); + KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context); context.reader().searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java index 3347b9478dd..1e32c07b665 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.IOException; import java.util.Arrays; import java.util.HashMap; @@ -32,6 +34,8 @@ import org.apache.lucene.document.IntField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; @@ -475,6 +479,62 @@ abstract class BaseVectorSimilarityQueryTestCase< } } + /** Test that the query times out correctly. */ + public void testTimeout() throws IOException { + V[] vectors = getRandomVectors(numDocs, dim); + V queryVector = getRandomVector(dim); + + try (Directory indexStore = getIndexStore(vectors); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + // This query is cacheable, explicitly prevent it + searcher.setQueryCache(null); + + Query query = + new CountingQuery( + getVectorQuery( + vectorField, + queryVector, + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null)); + + assertEquals(numDocs, searcher.count(query)); // Expect some results without timeout + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + + searcher.setTimeout(new CountingQueryTimeout(numDocs - 1)); // Do not score all docs + int count = searcher.count(query); + assertTrue( + "0 < count=" + count + " < numDocs=" + numDocs, + count > 0 && count < numDocs); // Expect partial results + + // Test timeout with filter + int numFiltered = random().nextInt(numDocs / 2, numDocs); + Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered)); + Query filteredQuery = + new CountingQuery( + getVectorQuery( + vectorField, + queryVector, + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + filter)); + + searcher.setTimeout(() -> false); // Set a timeout which is never met + assertEquals(numFiltered, searcher.count(filteredQuery)); + + searcher.setTimeout( + new CountingQueryTimeout(numFiltered - 1)); // Timeout before scoring all filtered docs + int filteredCount = searcher.count(filteredQuery); + assertTrue( + "0 < filteredCount=" + filteredCount + " < numFiltered=" + numFiltered, + filteredCount > 0 && filteredCount < numFiltered); // Expect partial results + } + } + private float getSimilarity(V[] vectors, V queryVector, int targetVisited) { assertTrue(targetVisited >= 0 && targetVisited <= numDocs); if (targetVisited == 0) { @@ -526,4 +586,94 @@ abstract class BaseVectorSimilarityQueryTestCase< } return dir; } + + private static class CountingQueryTimeout implements QueryTimeout { + private int remaining; + + public CountingQueryTimeout(int count) { + remaining = count; + } + + @Override + public boolean shouldExit() { + if (remaining > 0) { + remaining--; + return false; + } + return true; + } + } + + /** + * A {@link Query} that emulates {@link Weight#count(LeafReaderContext)} by counting number of + * docs of underlying {@link Scorer#iterator()}. TODO: This is a workaround to count partial + * results of {@link #delegate} because {@link TimeLimitingBulkScorer} immediately discards + * results after timeout. + */ + private static class CountingQuery extends Query { + private final Query delegate; + + private CountingQuery(Query delegate) { + this.delegate = delegate; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + return new Weight(this) { + final Weight delegateWeight = delegate.createWeight(searcher, scoreMode, boost); + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return delegateWeight.explain(context, doc); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return delegateWeight.scorerSupplier(context); + } + + @Override + public int count(LeafReaderContext context) throws IOException { + Scorer scorer = scorer(context); + if (scorer == null) { + return 0; + } + + int count = 0; + DocIdSetIterator iterator = scorer.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + count++; + } + return count; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return delegateWeight.isCacheable(ctx); + } + }; + } + + @Override + public String toString(String field) { + return String.format( + Locale.ROOT, "%s[%s]", getClass().getSimpleName(), delegate.toString(field)); + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + return sameClassAs(obj) && delegate.equals(((CountingQuery) obj).delegate); + } + + @Override + public int hashCode() { + return delegate.hashCode(); + } + } }