mirror of https://github.com/apache/lucene.git
Add timeout support to AbstractVectorSimilarityQuery (#13285)
Co-authored-by: Kaival Parikh <kaivalnp@amazon.com>
This commit is contained in:
parent
43c80117dd
commit
e0e5d81df8
|
@ -288,6 +288,9 @@ Improvements
|
|||
|
||||
* GITHUB#13625: Remove BitSet#nextSetBit code duplication. (Greg Miller)
|
||||
|
||||
* GITHUB#13285: Early terminate graph searches of AbstractVectorSimilarityQuery to follow timeout set from
|
||||
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)
|
||||
|
||||
Optimizations
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ import java.util.List;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
this.filter = filter;
|
||||
}
|
||||
|
||||
protected KnnCollectorManager getKnnCollectorManager() {
|
||||
return (visitedLimit, context) ->
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitedLimit);
|
||||
}
|
||||
|
||||
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
|
||||
|
||||
protected abstract TopDocs approximateSearch(
|
||||
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException;
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
|
@ -72,6 +83,10 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
? null
|
||||
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
|
||||
|
||||
final QueryTimeout queryTimeout = searcher.getTimeout();
|
||||
final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
|
||||
new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout);
|
||||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
if (filterWeight != null) {
|
||||
|
@ -103,16 +118,14 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
LeafReader leafReader = context.reader();
|
||||
Bits liveDocs = leafReader.getLiveDocs();
|
||||
final Scorer vectorSimilarityScorer;
|
||||
|
||||
// If there is no filter
|
||||
if (filterWeight == null) {
|
||||
// Return exhaustive results
|
||||
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
|
||||
if (results.scoreDocs.length == 0) {
|
||||
return null;
|
||||
}
|
||||
vectorSimilarityScorer =
|
||||
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
|
||||
TopDocs results =
|
||||
approximateSearch(
|
||||
context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
|
||||
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
|
||||
} else {
|
||||
Scorer scorer = filterWeight.scorer(context);
|
||||
if (scorer == null) {
|
||||
|
@ -143,27 +156,23 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
}
|
||||
|
||||
// Perform an approximate search
|
||||
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
|
||||
TopDocs results =
|
||||
approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager);
|
||||
|
||||
// If the limit was exhausted
|
||||
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
|
||||
// Return a lazy-loading iterator
|
||||
vectorSimilarityScorer =
|
||||
VectorSimilarityScorer.fromAcceptDocs(
|
||||
this,
|
||||
boost,
|
||||
createVectorScorer(context),
|
||||
new BitSetIterator(acceptDocs, cardinality),
|
||||
resultSimilarity);
|
||||
} else if (results.scoreDocs.length == 0) {
|
||||
return null;
|
||||
} else {
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
|
||||
// Return partial results only when timeout is met
|
||||
|| (queryTimeout != null && queryTimeout.shouldExit())) {
|
||||
// Return an iterator over the collected results
|
||||
vectorSimilarityScorer =
|
||||
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
|
||||
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
|
||||
} else {
|
||||
// Return a lazy-loading iterator
|
||||
return VectorSimilarityScorerSupplier.fromAcceptDocs(
|
||||
boost,
|
||||
createVectorScorer(context),
|
||||
new BitSetIterator(acceptDocs, cardinality),
|
||||
resultSimilarity);
|
||||
}
|
||||
}
|
||||
return new DefaultScorerSupplier(vectorSimilarityScorer);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -197,16 +206,20 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
private static class VectorSimilarityScorer extends Scorer {
|
||||
private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
|
||||
final DocIdSetIterator iterator;
|
||||
final float[] cachedScore;
|
||||
|
||||
VectorSimilarityScorer(DocIdSetIterator iterator, float[] cachedScore) {
|
||||
VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) {
|
||||
this.iterator = iterator;
|
||||
this.cachedScore = cachedScore;
|
||||
}
|
||||
|
||||
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
|
||||
static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
|
||||
if (scoreDocs.length == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Sort in ascending order of docid
|
||||
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
|
||||
|
||||
|
@ -252,18 +265,15 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
}
|
||||
};
|
||||
|
||||
return new VectorSimilarityScorer(iterator, cachedScore);
|
||||
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
|
||||
}
|
||||
|
||||
static VectorSimilarityScorer fromAcceptDocs(
|
||||
Weight weight,
|
||||
float boost,
|
||||
VectorScorer scorer,
|
||||
DocIdSetIterator acceptDocs,
|
||||
float threshold) {
|
||||
static VectorSimilarityScorerSupplier fromAcceptDocs(
|
||||
float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
|
||||
if (scorer == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
float[] cachedScore = new float[1];
|
||||
DocIdSetIterator vectorIterator = scorer.iterator();
|
||||
DocIdSetIterator conjunction =
|
||||
|
@ -281,27 +291,37 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
}
|
||||
};
|
||||
|
||||
return new VectorSimilarityScorer(iterator, cachedScore);
|
||||
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return iterator.docID();
|
||||
public Scorer get(long leadCost) {
|
||||
return new Scorer() {
|
||||
@Override
|
||||
public int docID() {
|
||||
return iterator.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getMaxScore(int upTo) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() {
|
||||
return cachedScore[0];
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getMaxScore(int upTo) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() {
|
||||
return cachedScore[0];
|
||||
public long cost() {
|
||||
return iterator.cost();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Objects;
|
|||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
|
@ -106,10 +107,13 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
|||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
KnnCollector collector =
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
|
||||
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Objects;
|
|||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
|
@ -108,10 +109,13 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
|||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
KnnCollector collector =
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
|
||||
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
@ -32,6 +34,8 @@ import org.apache.lucene.document.IntField;
|
|||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
|
@ -475,6 +479,62 @@ abstract class BaseVectorSimilarityQueryTestCase<
|
|||
}
|
||||
}
|
||||
|
||||
/** Test that the query times out correctly. */
|
||||
public void testTimeout() throws IOException {
|
||||
V[] vectors = getRandomVectors(numDocs, dim);
|
||||
V queryVector = getRandomVector(dim);
|
||||
|
||||
try (Directory indexStore = getIndexStore(vectors);
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
// This query is cacheable, explicitly prevent it
|
||||
searcher.setQueryCache(null);
|
||||
|
||||
Query query =
|
||||
new CountingQuery(
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
queryVector,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null));
|
||||
|
||||
assertEquals(numDocs, searcher.count(query)); // Expect some results without timeout
|
||||
|
||||
searcher.setTimeout(() -> true); // Immediately timeout
|
||||
assertEquals(0, searcher.count(query)); // Expect no results with the timeout
|
||||
|
||||
searcher.setTimeout(new CountingQueryTimeout(numDocs - 1)); // Do not score all docs
|
||||
int count = searcher.count(query);
|
||||
assertTrue(
|
||||
"0 < count=" + count + " < numDocs=" + numDocs,
|
||||
count > 0 && count < numDocs); // Expect partial results
|
||||
|
||||
// Test timeout with filter
|
||||
int numFiltered = random().nextInt(numDocs / 2, numDocs);
|
||||
Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
|
||||
Query filteredQuery =
|
||||
new CountingQuery(
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
queryVector,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
filter));
|
||||
|
||||
searcher.setTimeout(() -> false); // Set a timeout which is never met
|
||||
assertEquals(numFiltered, searcher.count(filteredQuery));
|
||||
|
||||
searcher.setTimeout(
|
||||
new CountingQueryTimeout(numFiltered - 1)); // Timeout before scoring all filtered docs
|
||||
int filteredCount = searcher.count(filteredQuery);
|
||||
assertTrue(
|
||||
"0 < filteredCount=" + filteredCount + " < numFiltered=" + numFiltered,
|
||||
filteredCount > 0 && filteredCount < numFiltered); // Expect partial results
|
||||
}
|
||||
}
|
||||
|
||||
private float getSimilarity(V[] vectors, V queryVector, int targetVisited) {
|
||||
assertTrue(targetVisited >= 0 && targetVisited <= numDocs);
|
||||
if (targetVisited == 0) {
|
||||
|
@ -526,4 +586,94 @@ abstract class BaseVectorSimilarityQueryTestCase<
|
|||
}
|
||||
return dir;
|
||||
}
|
||||
|
||||
private static class CountingQueryTimeout implements QueryTimeout {
|
||||
private int remaining;
|
||||
|
||||
public CountingQueryTimeout(int count) {
|
||||
remaining = count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldExit() {
|
||||
if (remaining > 0) {
|
||||
remaining--;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@link Query} that emulates {@link Weight#count(LeafReaderContext)} by counting number of
|
||||
* docs of underlying {@link Scorer#iterator()}. TODO: This is a workaround to count partial
|
||||
* results of {@link #delegate} because {@link TimeLimitingBulkScorer} immediately discards
|
||||
* results after timeout.
|
||||
*/
|
||||
private static class CountingQuery extends Query {
|
||||
private final Query delegate;
|
||||
|
||||
private CountingQuery(Query delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
throws IOException {
|
||||
return new Weight(this) {
|
||||
final Weight delegateWeight = delegate.createWeight(searcher, scoreMode, boost);
|
||||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
return delegateWeight.explain(context, doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
return delegateWeight.scorerSupplier(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int count(LeafReaderContext context) throws IOException {
|
||||
Scorer scorer = scorer(context);
|
||||
if (scorer == null) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int count = 0;
|
||||
DocIdSetIterator iterator = scorer.iterator();
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
count++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return delegateWeight.isCacheable(ctx);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return String.format(
|
||||
Locale.ROOT, "%s[%s]", getClass().getSimpleName(), delegate.toString(field));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
visitor.visitLeaf(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
return sameClassAs(obj) && delegate.equals(((CountingQuery) obj).delegate);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return delegate.hashCode();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue