mirror of https://github.com/apache/lucene.git
Add timeout support to AbstractKnnVectorQuery (#13202)
* Add timeout support to graph searches in AbstractKnnVectorQuery * Also timeout exact searches * Return partial KNN results * Add tests for partial KNN results - Refactor tests to base classes - Also timeout exact searches in Lucene99HnswVectorsReader * Add CHANGES.txt entry and fix some comments --------- Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
parent
45c4f8c052
commit
df154cdc22
|
@ -241,6 +241,9 @@ Improvements
|
|||
implementation is the ConcurrentMergeScheduler and the Lucene99HnswVectorsFormat will use it if no other
|
||||
executor is provided. (Ben Trent)
|
||||
|
||||
* GITHUB#13202: Early terminate graph and exact searches of AbstractKnnVectorQuery to follow timeout set from
|
||||
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)
|
||||
|
||||
Optimizations
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -260,6 +260,9 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
// and collect them
|
||||
for (int i = 0; i < scorer.maxOrd(); i++) {
|
||||
if (acceptedOrds == null || acceptedOrds.get(i)) {
|
||||
if (knnCollector.earlyTerminated()) {
|
||||
break;
|
||||
}
|
||||
knnCollector.incVisitedCount(1);
|
||||
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
|
||||
}
|
||||
|
@ -288,6 +291,9 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
// and collect them
|
||||
for (int i = 0; i < scorer.maxOrd(); i++) {
|
||||
if (acceptedOrds == null || acceptedOrds.get(i)) {
|
||||
if (knnCollector.earlyTerminated()) {
|
||||
break;
|
||||
}
|
||||
knnCollector.incVisitedCount(1);
|
||||
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.search.knn.TopKnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
@ -81,7 +82,9 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
filterWeight = null;
|
||||
}
|
||||
|
||||
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
|
||||
TimeLimitingKnnCollectorManager knnCollectorManager =
|
||||
new TimeLimitingKnnCollectorManager(
|
||||
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
|
||||
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
|
||||
List<LeafReaderContext> leafReaderContexts = reader.leaves();
|
||||
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
|
||||
|
@ -99,9 +102,11 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
}
|
||||
|
||||
private TopDocs searchLeaf(
|
||||
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
|
||||
LeafReaderContext ctx,
|
||||
Weight filterWeight,
|
||||
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
|
||||
throws IOException {
|
||||
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
|
||||
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
|
@ -111,13 +116,15 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
}
|
||||
|
||||
private TopDocs getLeafResults(
|
||||
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
|
||||
LeafReaderContext ctx,
|
||||
Weight filterWeight,
|
||||
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
|
||||
throws IOException {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
int maxDoc = ctx.reader().maxDoc();
|
||||
|
||||
if (filterWeight == null) {
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
|
||||
}
|
||||
|
||||
Scorer scorer = filterWeight.scorer(ctx);
|
||||
|
@ -127,21 +134,24 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
|
||||
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
|
||||
final int cost = acceptDocs.cardinality();
|
||||
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
|
||||
|
||||
if (cost <= k) {
|
||||
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
||||
// must always visit at least k documents
|
||||
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
|
||||
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
|
||||
}
|
||||
|
||||
// Perform the approximate kNN search
|
||||
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
|
||||
// Return partial results only when timeout is met
|
||||
|| (queryTimeout != null && queryTimeout.shouldExit())) {
|
||||
return results;
|
||||
} else {
|
||||
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
||||
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
|
||||
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,7 +188,8 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
throws IOException;
|
||||
|
||||
// We allow this to be overridden so that tests can check what search strategy is used
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
protected TopDocs exactSearch(
|
||||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
|
||||
throws IOException {
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
|
@ -192,9 +203,16 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
}
|
||||
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
|
||||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
ScoreDoc topDoc = queue.top();
|
||||
int doc;
|
||||
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
// Mark results as partial if timeout is met
|
||||
if (queryTimeout != null && queryTimeout.shouldExit()) {
|
||||
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||
break;
|
||||
}
|
||||
|
||||
boolean advanced = vectorScorer.advanceExact(doc);
|
||||
assert advanced;
|
||||
|
||||
|
@ -216,7 +234,7 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
|
|
|
@ -483,6 +483,14 @@ public class IndexSearcher {
|
|||
return search(query, manager);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the configured {@link QueryTimeout} for all searches that run through this {@link
|
||||
* IndexSearcher}, or {@code null} if not set.
|
||||
*/
|
||||
public QueryTimeout getTimeout() {
|
||||
return this.queryTimeout;
|
||||
}
|
||||
|
||||
/** Set a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */
|
||||
public void setTimeout(QueryTimeout queryTimeout) {
|
||||
this.queryTimeout = queryTimeout;
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
|
||||
/** A {@link KnnCollectorManager} that collects results with a timeout. */
|
||||
public class TimeLimitingKnnCollectorManager implements KnnCollectorManager {
|
||||
private final KnnCollectorManager delegate;
|
||||
private final QueryTimeout queryTimeout;
|
||||
|
||||
public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeout timeout) {
|
||||
this.delegate = delegate;
|
||||
this.queryTimeout = timeout;
|
||||
}
|
||||
|
||||
/** Get the configured {@link QueryTimeout} for terminating graph and exact searches. */
|
||||
public QueryTimeout getQueryTimeout() {
|
||||
return queryTimeout;
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
|
||||
KnnCollector collector = delegate.newCollector(visitedLimit, context);
|
||||
if (queryTimeout == null) {
|
||||
return collector;
|
||||
}
|
||||
return new KnnCollector() {
|
||||
@Override
|
||||
public boolean earlyTerminated() {
|
||||
return queryTimeout.shouldExit() || collector.earlyTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incVisitedCount(int count) {
|
||||
collector.incVisitedCount(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long visitedCount() {
|
||||
return collector.visitedCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long visitLimit() {
|
||||
return collector.visitLimit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int k() {
|
||||
return collector.k();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean collect(int docId, float similarity) {
|
||||
return collector.collect(docId, similarity);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float minCompetitiveSimilarity() {
|
||||
return collector.minCompetitiveSimilarity();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs topDocs() {
|
||||
TopDocs docs = collector.topDocs();
|
||||
|
||||
// Mark results as partial if timeout is met
|
||||
TotalHits.Relation relation =
|
||||
queryTimeout.shouldExit()
|
||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|
||||
: docs.totalHits.relation;
|
||||
|
||||
return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -39,6 +39,7 @@ import org.apache.lucene.index.IndexWriter;
|
|||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -765,6 +766,34 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Test that the query times out correctly. */
|
||||
public void testTimeout() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2);
|
||||
AbstractKnnVectorQuery exactQuery =
|
||||
getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery());
|
||||
|
||||
assertEquals(2, searcher.count(query)); // Expect some results without timeout
|
||||
assertEquals(3, searcher.count(exactQuery)); // Same for exact search
|
||||
|
||||
searcher.setTimeout(() -> true); // Immediately timeout
|
||||
assertEquals(0, searcher.count(query)); // Expect no results with the timeout
|
||||
assertEquals(0, searcher.count(exactQuery)); // Same for exact search
|
||||
|
||||
searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 doc
|
||||
// Note: This depends on the HNSW graph having just one layer,
|
||||
// would be 0 in case of multiple layers
|
||||
assertEquals(1, searcher.count(query)); // Expect only 1 result
|
||||
|
||||
searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 doc
|
||||
assertEquals(1, searcher.count(exactQuery)); // Expect only 1 result
|
||||
}
|
||||
}
|
||||
|
||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||
Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents);
|
||||
|
@ -1006,4 +1035,21 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class CountingQueryTimeout implements QueryTimeout {
|
||||
private int remaining;
|
||||
|
||||
public CountingQueryTimeout(int count) {
|
||||
remaining = count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldExit() {
|
||||
if (remaining > 0) {
|
||||
remaining--;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.lucene.document.KnnByteVectorField;
|
|||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.TestVectorUtil;
|
||||
|
@ -109,7 +110,8 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
|
||||
protected TopDocs exactSearch(
|
||||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
|
||||
throw new UnsupportedOperationException("exact search is not supported");
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.apache.lucene.index.IndexReader;
|
|||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
|
@ -258,7 +259,8 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
|
||||
protected TopDocs exactSearch(
|
||||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
|
||||
throw new UnsupportedOperationException("exact search is not supported");
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.Objects;
|
|||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
|
@ -77,7 +78,8 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
protected TopDocs exactSearch(
|
||||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
|
||||
throws IOException {
|
||||
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
|
||||
if (byteVectorValues == null) {
|
||||
|
@ -100,8 +102,15 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
fi.getVectorSimilarityFunction());
|
||||
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
|
||||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
ScoreDoc topDoc = queue.top();
|
||||
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
// Mark results as partial if timeout is met
|
||||
if (queryTimeout != null && queryTimeout.shouldExit()) {
|
||||
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||
break;
|
||||
}
|
||||
|
||||
float score = vectorScorer.score();
|
||||
if (score > topDoc.score) {
|
||||
topDoc.score = score;
|
||||
|
@ -120,7 +129,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.Objects;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
|
@ -77,7 +78,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
protected TopDocs exactSearch(
|
||||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
|
||||
throws IOException {
|
||||
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
|
||||
if (floatVectorValues == null) {
|
||||
|
@ -100,8 +102,15 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
fi.getVectorSimilarityFunction());
|
||||
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
|
||||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
ScoreDoc topDoc = queue.top();
|
||||
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
// Mark results as partial if timeout is met
|
||||
if (queryTimeout != null && queryTimeout.shouldExit()) {
|
||||
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||
break;
|
||||
}
|
||||
|
||||
float score = vectorScorer.score();
|
||||
if (score > topDoc.score) {
|
||||
topDoc.score = score;
|
||||
|
@ -120,7 +129,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.apache.lucene.index.DirectoryReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -283,6 +284,36 @@ abstract class ParentBlockJoinKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Test that the query times out correctly. */
|
||||
public void testTimeout() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
BitSetProducer parentFilter = parentFilter(reader);
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query = getParentJoinKnnQuery("field", new float[] {1, 2}, null, 2, parentFilter);
|
||||
Query exactQuery =
|
||||
getParentJoinKnnQuery(
|
||||
"field", new float[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter);
|
||||
|
||||
assertEquals(2, searcher.count(query)); // Expect some results without timeout
|
||||
assertEquals(3, searcher.count(exactQuery)); // Same for exact search
|
||||
|
||||
searcher.setTimeout(() -> true); // Immediately timeout
|
||||
assertEquals(0, searcher.count(query)); // Expect no results with the timeout
|
||||
assertEquals(0, searcher.count(exactQuery)); // Same for exact search
|
||||
|
||||
searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 parent
|
||||
// Note: This depends on the HNSW graph having just one layer,
|
||||
// would be 0 in case of multiple layers
|
||||
assertEquals(1, searcher.count(query)); // Expect only 1 result
|
||||
|
||||
searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 parent
|
||||
assertEquals(1, searcher.count(exactQuery)); // Expect only 1 result
|
||||
}
|
||||
}
|
||||
|
||||
Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
RandomIndexWriter writer =
|
||||
|
@ -352,4 +383,21 @@ abstract class ParentBlockJoinKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
assertEquals(idToScore.get(actualId), scorer.score(), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
private static class CountingQueryTimeout implements QueryTimeout {
|
||||
private int remaining;
|
||||
|
||||
public CountingQueryTimeout(int count) {
|
||||
remaining = count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldExit() {
|
||||
if (remaining > 0) {
|
||||
remaining--;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue