Add timeout support to AbstractVectorSimilarityQuery (#13285)

Co-authored-by: Kaival Parikh <kaivalnp@amazon.com>
This commit is contained in:
Kaival Parikh 2024-08-06 05:42:19 +05:30 committed by GitHub
parent 43c80117dd
commit e0e5d81df8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 237 additions and 56 deletions

View File

@ -288,6 +288,9 @@ Improvements
* GITHUB#13625: Remove BitSet#nextSetBit code duplication. (Greg Miller)
* GITHUB#13285: Terminate graph searches of AbstractVectorSimilarityQuery early to honor the timeout set via
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)
Optimizations
---------------------

View File

@ -23,6 +23,8 @@ import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query {
this.filter = filter;
}
/**
 * Supplies the {@link KnnCollectorManager} used to create collectors for this query.
 *
 * <p>Each collector produced is a {@code VectorSimilarityCollector} configured with this query's
 * {@code traversalSimilarity} and {@code resultSimilarity} and the visit limit requested by the
 * caller. The leaf context handed to the manager is not consulted.
 *
 * @return a manager that builds one similarity collector per request
 */
protected KnnCollectorManager getKnnCollectorManager() {
  return (limit, leafContext) -> {
    return new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, limit);
  };
}
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
@ -72,6 +83,10 @@ abstract class AbstractVectorSimilarityQuery extends Query {
? null
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
final QueryTimeout queryTimeout = searcher.getTimeout();
final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout);
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
if (filterWeight != null) {
@ -103,16 +118,14 @@ abstract class AbstractVectorSimilarityQuery extends Query {
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader leafReader = context.reader();
Bits liveDocs = leafReader.getLiveDocs();
final Scorer vectorSimilarityScorer;
// If there is no filter
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
if (results.scoreDocs.length == 0) {
return null;
}
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
TopDocs results =
approximateSearch(
context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
} else {
Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
@ -143,27 +156,23 @@ abstract class AbstractVectorSimilarityQuery extends Query {
}
// Perform an approximate search
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
TopDocs results =
approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager);
// If the limit was exhausted
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
// Return a lazy-loading iterator
vectorSimilarityScorer =
VectorSimilarityScorer.fromAcceptDocs(
this,
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else if (results.scoreDocs.length == 0) {
return null;
} else {
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
// Return an iterator over the collected results
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
} else {
// Return a lazy-loading iterator
return VectorSimilarityScorerSupplier.fromAcceptDocs(
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
}
}
return new DefaultScorerSupplier(vectorSimilarityScorer);
}
@Override
@ -197,16 +206,20 @@ abstract class AbstractVectorSimilarityQuery extends Query {
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
}
private static class VectorSimilarityScorer extends Scorer {
private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
final DocIdSetIterator iterator;
final float[] cachedScore;
VectorSimilarityScorer(DocIdSetIterator iterator, float[] cachedScore) {
VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) {
this.iterator = iterator;
this.cachedScore = cachedScore;
}
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
if (scoreDocs.length == 0) {
return null;
}
// Sort in ascending order of docid
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
@ -252,18 +265,15 @@ abstract class AbstractVectorSimilarityQuery extends Query {
}
};
return new VectorSimilarityScorer(iterator, cachedScore);
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
}
static VectorSimilarityScorer fromAcceptDocs(
Weight weight,
float boost,
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
static VectorSimilarityScorerSupplier fromAcceptDocs(
float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
if (scorer == null) {
return null;
}
float[] cachedScore = new float[1];
DocIdSetIterator vectorIterator = scorer.iterator();
DocIdSetIterator conjunction =
@ -281,27 +291,37 @@ abstract class AbstractVectorSimilarityQuery extends Query {
}
};
return new VectorSimilarityScorer(iterator, cachedScore);
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
}
@Override
public int docID() {
return iterator.docID();
public Scorer get(long leadCost) {
return new Scorer() {
@Override
public int docID() {
return iterator.docID();
}
@Override
public DocIdSetIterator iterator() {
return iterator;
}
@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
}
@Override
public float score() {
return cachedScore[0];
}
};
}
@Override
public DocIdSetIterator iterator() {
return iterator;
}
@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
}
@Override
public float score() {
return cachedScore[0];
public long cost() {
return iterator.cost();
}
}
}

View File

@ -23,6 +23,7 @@ import java.util.Objects;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;
/**
@ -106,10 +107,13 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}

View File

@ -23,6 +23,7 @@ import java.util.Objects;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
@ -108,10 +109,13 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}

View File

@ -16,6 +16,8 @@
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
@ -32,6 +34,8 @@ import org.apache.lucene.document.IntField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
@ -475,6 +479,62 @@ abstract class BaseVectorSimilarityQueryTestCase<
}
}
/**
 * Test that the query times out correctly.
 *
 * <p>Phases, in order: no timeout (all docs counted), a timeout that is always met (nothing
 * counted), a timeout met after {@code numDocs - 1} checks (partial count), then the same
 * never-met / partially-met checks repeated with a filter applied.
 */
public void testTimeout() throws IOException {
  V[] vectors = getRandomVectors(numDocs, dim);
  V queryVector = getRandomVector(dim);

  try (Directory indexStore = getIndexStore(vectors);
      IndexReader reader = DirectoryReader.open(indexStore)) {
    IndexSearcher searcher = newSearcher(reader);

    // The query is cacheable, so switch the cache off explicitly
    searcher.setQueryCache(null);

    // Both similarity thresholds are -infinity and no filter is supplied
    Query unfilteredVectorQuery =
        getVectorQuery(
            vectorField, queryVector, Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY, null);
    Query query = new CountingQuery(unfilteredVectorQuery);

    // Without any timeout every document is expected
    int countWithoutTimeout = searcher.count(query);
    assertEquals(numDocs, countWithoutTimeout);

    // A timeout that is met on the first check: nothing is expected
    searcher.setTimeout(() -> true);
    int countWithImmediateTimeout = searcher.count(query);
    assertEquals(0, countWithImmediateTimeout);

    // A timeout that is met before all docs can be scored: partial results are expected
    searcher.setTimeout(new CountingQueryTimeout(numDocs - 1));
    int count = searcher.count(query);
    boolean partialCount = 0 < count && count < numDocs;
    assertTrue("0 < count=" + count + " < numDocs=" + numDocs, partialCount);

    // Same checks, this time with a filter
    int numFiltered = random().nextInt(numDocs / 2, numDocs);
    Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
    Query filteredVectorQuery =
        getVectorQuery(
            vectorField, queryVector, Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY, filter);
    Query filteredQuery = new CountingQuery(filteredVectorQuery);

    // A timeout that is never met: every filtered document is expected
    searcher.setTimeout(() -> false);
    int filteredCountWithoutExit = searcher.count(filteredQuery);
    assertEquals(numFiltered, filteredCountWithoutExit);

    // A timeout that is met before all filtered docs can be scored: partial results again
    searcher.setTimeout(new CountingQueryTimeout(numFiltered - 1));
    int filteredCount = searcher.count(filteredQuery);
    boolean partialFilteredCount = 0 < filteredCount && filteredCount < numFiltered;
    assertTrue(
        "0 < filteredCount=" + filteredCount + " < numFiltered=" + numFiltered,
        partialFilteredCount);
  }
}
private float getSimilarity(V[] vectors, V queryVector, int targetVisited) {
assertTrue(targetVisited >= 0 && targetVisited <= numDocs);
if (targetVisited == 0) {
@ -526,4 +586,94 @@ abstract class BaseVectorSimilarityQueryTestCase<
}
return dir;
}
/**
 * A {@link QueryTimeout} driven by a countdown instead of a clock: {@link #shouldExit()} answers
 * {@code false} for the first {@code count} calls (when {@code count} is positive) and {@code
 * true} on every call after that.
 */
private static class CountingQueryTimeout implements QueryTimeout {
  /** Number of {@code false} answers still to be given. */
  private int remaining;

  public CountingQueryTimeout(int count) {
    remaining = count;
  }

  @Override
  public boolean shouldExit() {
    // Budget used up (or never positive): signal the timeout
    if (remaining <= 0) {
      return true;
    }
    // Otherwise spend one unit of the budget and let the search continue
    remaining--;
    return false;
  }
}
/**
 * A {@link Query} wrapper whose {@link Weight#count(LeafReaderContext)} is computed by walking
 * the wrapped query's {@link Scorer#iterator()} and tallying the docs it returns. Explanation,
 * scorer supplier and cacheability are forwarded to the wrapped weight. TODO: This is a
 * workaround to count partial results of {@link #delegate} because {@link
 * TimeLimitingBulkScorer} immediately discards results after timeout.
 */
private static class CountingQuery extends Query {
  private final Query delegate;

  private CountingQuery(Query delegate) {
    this.delegate = delegate;
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
      throws IOException {
    // Build the wrapped weight once; the anonymous weight below forwards to it
    final Weight wrapped = delegate.createWeight(searcher, scoreMode, boost);

    return new Weight(this) {
      @Override
      public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        return wrapped.explain(context, doc);
      }

      @Override
      public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        return wrapped.scorerSupplier(context);
      }

      @Override
      public int count(LeafReaderContext context) throws IOException {
        Scorer leafScorer = scorer(context);
        if (leafScorer == null) {
          // No scorer for this leaf: nothing to count
          return 0;
        }

        // Step through every doc of the iterator, counting as we go
        DocIdSetIterator docs = leafScorer.iterator();
        int matched = 0;
        for (int doc = docs.nextDoc(); doc != NO_MORE_DOCS; doc = docs.nextDoc()) {
          matched++;
        }
        return matched;
      }

      @Override
      public boolean isCacheable(LeafReaderContext ctx) {
        return wrapped.isCacheable(ctx);
      }
    };
  }

  @Override
  public String toString(String field) {
    String wrapperName = getClass().getSimpleName();
    String wrappedDescription = delegate.toString(field);
    return String.format(Locale.ROOT, "%s[%s]", wrapperName, wrappedDescription);
  }

  @Override
  public void visit(QueryVisitor visitor) {
    visitor.visitLeaf(this);
  }

  @Override
  public boolean equals(Object obj) {
    if (!sameClassAs(obj)) {
      return false;
    }
    CountingQuery other = (CountingQuery) obj;
    return delegate.equals(other.delegate);
  }

  @Override
  public int hashCode() {
    return delegate.hashCode();
  }
}
}