Add timeout support to AbstractKnnVectorQuery (#13202)

* Add timeout support to graph searches in AbstractKnnVectorQuery

* Also timeout exact searches

* Return partial KNN results

* Add tests for partial KNN results

- Refactor tests to base classes
- Also timeout exact searches in Lucene99HnswVectorsReader

* Add CHANGES.txt entry and fix some comments

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
Kaival Parikh 2024-04-03 19:23:24 +05:30 committed by GitHub
parent 45c4f8c052
commit df154cdc22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 263 additions and 17 deletions

View File

@ -241,6 +241,9 @@ Improvements
implementation is the ConcurrentMergeScheduler and the Lucene99HnswVectorsFormat will use it if no other
executor is provided. (Ben Trent)
* GITHUB#13202: Terminate graph and exact searches of AbstractKnnVectorQuery early to honor the timeout set via
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)
Optimizations
---------------------

View File

@ -260,6 +260,9 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
@ -288,6 +291,9 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}

View File

@ -29,6 +29,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
@ -81,7 +82,9 @@ abstract class AbstractKnnVectorQuery extends Query {
filterWeight = null;
}
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
@ -99,9 +102,11 @@ abstract class AbstractKnnVectorQuery extends Query {
}
private TopDocs searchLeaf(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
@ -111,13 +116,15 @@ abstract class AbstractKnnVectorQuery extends Query {
}
private TopDocs getLeafResults(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
}
Scorer scorer = filterWeight.scorer(ctx);
@ -127,21 +134,24 @@ abstract class AbstractKnnVectorQuery extends Query {
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
final int cost = acceptDocs.cardinality();
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}
// Perform the approximate kNN search
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}
}
@ -178,7 +188,8 @@ abstract class AbstractKnnVectorQuery extends Query {
throws IOException;
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
@ -192,9 +203,16 @@ abstract class AbstractKnnVectorQuery extends Query {
}
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
@ -216,7 +234,7 @@ abstract class AbstractKnnVectorQuery extends Query {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

View File

@ -483,6 +483,14 @@ public class IndexSearcher {
return search(query, manager);
}
/**
 * Returns the {@link QueryTimeout} applied to all searches that run through this {@link
 * IndexSearcher}.
 *
 * @return the configured timeout, or {@code null} if none has been set
 */
public QueryTimeout getTimeout() {
  return queryTimeout;
}
/** Set a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */
public void setTimeout(QueryTimeout queryTimeout) {
this.queryTimeout = queryTimeout;

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
/**
 * A {@link KnnCollectorManager} that wraps the collectors of another manager so that they also
 * stop once a {@link QueryTimeout} reports that the search should exit.
 */
public class TimeLimitingKnnCollectorManager implements KnnCollectorManager {
  private final KnnCollectorManager delegate;
  private final QueryTimeout queryTimeout;

  /**
   * Creates a manager that adds timeout checks to the collectors produced by {@code delegate}.
   *
   * @param delegate manager that creates the underlying collectors
   * @param timeout timeout to consult; when {@code null} the delegate's collectors are returned
   *     unwrapped
   */
  public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeout timeout) {
    this.delegate = delegate;
    this.queryTimeout = timeout;
  }

  /** Get the configured {@link QueryTimeout} for terminating graph and exact searches. */
  public QueryTimeout getQueryTimeout() {
    return queryTimeout;
  }

  /**
   * Asks the delegate for a collector and, if a timeout is configured, wraps it so that early
   * termination and the reported hit relation take the timeout into account.
   */
  @Override
  public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
    KnnCollector inner = delegate.newCollector(visitedLimit, context);
    return queryTimeout == null ? inner : new TimeLimitingKnnCollector(inner, queryTimeout);
  }

  /** Forwards every call to the wrapped collector, adding timeout checks where relevant. */
  private static final class TimeLimitingKnnCollector implements KnnCollector {
    private final KnnCollector in;
    private final QueryTimeout timeout;

    TimeLimitingKnnCollector(KnnCollector in, QueryTimeout timeout) {
      this.in = in;
      this.timeout = timeout;
    }

    @Override
    public boolean earlyTerminated() {
      // The timeout is consulted first; the wrapped collector is only asked if it has not fired
      if (timeout.shouldExit()) {
        return true;
      }
      return in.earlyTerminated();
    }

    @Override
    public void incVisitedCount(int count) {
      in.incVisitedCount(count);
    }

    @Override
    public long visitedCount() {
      return in.visitedCount();
    }

    @Override
    public long visitLimit() {
      return in.visitLimit();
    }

    @Override
    public int k() {
      return in.k();
    }

    @Override
    public boolean collect(int docId, float similarity) {
      return in.collect(docId, similarity);
    }

    @Override
    public float minCompetitiveSimilarity() {
      return in.minCompetitiveSimilarity();
    }

    @Override
    public TopDocs topDocs() {
      TopDocs collected = in.topDocs();
      TotalHits.Relation relation = collected.totalHits.relation;
      // Mark results as partial if timeout is met
      if (timeout.shouldExit()) {
        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
      }
      TotalHits hits = new TotalHits(collected.totalHits.value, relation);
      return new TopDocs(hits, collected.scoreDocs);
    }
  }
}

View File

@ -39,6 +39,7 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -765,6 +766,34 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
}
/**
 * Checks that kNN queries respect the {@link QueryTimeout} set on the searcher, covering both the
 * graph search and a filtered query that is expected to take the exact search path.
 */
public void testTimeout() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    IndexSearcher indexSearcher = newSearcher(indexReader);

    AbstractKnnVectorQuery graphQuery = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2);
    AbstractKnnVectorQuery filteredQuery =
        getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery());

    // Baseline: without a timeout both queries return hits
    assertEquals(2, indexSearcher.count(graphQuery));
    assertEquals(3, indexSearcher.count(filteredQuery));

    // A timeout that always reports exit leaves nothing to collect
    QueryTimeout alwaysExpired = () -> true;
    indexSearcher.setTimeout(alwaysExpired);
    assertEquals(0, indexSearcher.count(graphQuery));
    assertEquals(0, indexSearcher.count(filteredQuery));

    // Allow exactly one doc to be scored before the timeout fires.
    // Note: This depends on the HNSW graph having just one layer,
    // would be 0 in case of multiple layers
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(graphQuery));

    // Fresh counter so the exact search is also allowed to score exactly one doc
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(filteredQuery));
  }
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
Directory getIndexStore(String field, float[]... contents) throws IOException {
return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents);
@ -1006,4 +1035,21 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
}
}
/**
 * A {@link QueryTimeout} that answers {@code false} for a fixed number of {@link #shouldExit()}
 * calls and {@code true} for every call after that.
 */
private static class CountingQueryTimeout implements QueryTimeout {
  // Number of shouldExit() calls that may still return false
  private int remaining;

  public CountingQueryTimeout(int count) {
    this.remaining = count;
  }

  @Override
  public boolean shouldExit() {
    if (remaining <= 0) {
      return true;
    }
    remaining--;
    return false;
  }
}
}

View File

@ -22,6 +22,7 @@ import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;
@ -109,7 +110,8 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

View File

@ -34,6 +34,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
@ -258,7 +259,8 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

View File

@ -22,6 +22,7 @@ import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
@ -77,7 +78,8 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
@ -100,8 +102,15 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
fi.getVectorSimilarityFunction());
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
@ -120,7 +129,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

View File

@ -22,6 +22,7 @@ import java.util.Objects;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
@ -77,7 +78,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
@ -100,8 +102,15 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
fi.getVectorSimilarityFunction());
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
@ -120,7 +129,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

View File

@ -33,6 +33,7 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
@ -283,6 +284,36 @@ abstract class ParentBlockJoinKnnVectorQueryTestCase extends LuceneTestCase {
}
}
/**
 * Checks that parent-join kNN queries respect the {@link QueryTimeout} set on the searcher,
 * covering both the graph search and a filtered query that is expected to take the exact search
 * path.
 */
public void testTimeout() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    BitSetProducer parents = parentFilter(indexReader);
    IndexSearcher indexSearcher = newSearcher(indexReader);

    Query graphQuery = getParentJoinKnnQuery("field", new float[] {1, 2}, null, 2, parents);
    Query filteredQuery =
        getParentJoinKnnQuery("field", new float[] {1, 2}, new MatchAllDocsQuery(), 10, parents);

    // Baseline: without a timeout both queries return hits
    assertEquals(2, indexSearcher.count(graphQuery));
    assertEquals(3, indexSearcher.count(filteredQuery));

    // A timeout that always reports exit leaves nothing to collect
    QueryTimeout alwaysExpired = () -> true;
    indexSearcher.setTimeout(alwaysExpired);
    assertEquals(0, indexSearcher.count(graphQuery));
    assertEquals(0, indexSearcher.count(filteredQuery));

    // Allow exactly one parent to be scored before the timeout fires.
    // Note: This depends on the HNSW graph having just one layer,
    // would be 0 in case of multiple layers
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(graphQuery));

    // Fresh counter so the exact search is also allowed to score exactly one parent
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(filteredQuery));
  }
}
Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer =
@ -352,4 +383,21 @@ abstract class ParentBlockJoinKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(idToScore.get(actualId), scorer.score(), 0.0001);
}
}
/**
 * A {@link QueryTimeout} that answers {@code false} for a fixed number of {@link #shouldExit()}
 * calls and {@code true} for every call after that.
 */
private static class CountingQueryTimeout implements QueryTimeout {
  // Number of shouldExit() calls that may still return false
  private int remaining;

  public CountingQueryTimeout(int count) {
    this.remaining = count;
  }

  @Override
  public boolean shouldExit() {
    if (remaining <= 0) {
      return true;
    }
    remaining--;
    return false;
  }
}
}