LUCENE-10382: Support filtering in KnnVectorQuery (#656)

This PR adds support for a query filter in KnnVectorQuery. First, we gather the
query results for each leaf as a bit set. Then the HNSW search skips over the
non-matching documents (using the same approach as for live docs). To prevent
HNSW search from visiting too many documents when the filter is very selective,
we short-circuit if HNSW has already visited more than the number of documents
that match the filter, and execute an exact search instead. This bounds the
number of visited documents at roughly 2x the cost of just running the exact
filter, while in most cases HNSW completes successfully and does a lot better.

Co-authored-by: Joel Bernstein <jbernste@apache.org>
This commit is contained in:
Julie Tibshirani 2022-02-17 11:35:25 -08:00 committed by GitHub
parent c132bbf677
commit 8ca372573d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 740 additions and 81 deletions

View File

@ -158,6 +158,9 @@ New Features
* LUCENE-10415: FunctionScoreQuery and IndexOrDocValuesQuery delegate Weight#count. (Ignacio Vera)
* LUCENE-10382: Add support for filtering in KnnVectorQuery. This allows for finding the
nearest k documents that also match a query. (Julie Tibshirani, Joel Bernstein)
Improvements
---------------------

View File

@ -229,7 +229,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0) {

View File

@ -144,7 +144,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
VectorValues values = getVectorValues(field);
if (target.length != values.dimension()) {
throw new IllegalArgumentException(

View File

@ -100,7 +100,8 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

View File

@ -21,7 +21,9 @@ import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
@ -58,6 +60,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* true k closest neighbors. For large values of k (for example when k is close to the total
* number of documents), the search may also retrieve fewer than k documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
* contains the number of documents visited during the search. If the search stopped early because
* it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
*
@ -66,10 +74,11 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
*/
public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException;
public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread

View File

@ -86,7 +86,8 @@ public abstract class KnnVectorsWriter implements Closeable {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
});

View File

@ -219,7 +219,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0) {
@ -228,8 +229,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
NeighborQueue results =
HnswGraphSearcher.search(
target,
@ -237,7 +238,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
vectorValues,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
getAcceptOrds(acceptDocs, fieldEntry),
visitedLimit);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@ -247,11 +249,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), score);
}
// always return >= the case where we can assert == is only when there are fewer than topK
// vectors in the index
return new TopDocs(
new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
scoreDocs);
TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {

View File

@ -263,12 +263,13 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.search(field, target, k, acceptDocs);
return knnVectorsReader.search(field, target, k, acceptDocs, visitedLimit);
}
}

View File

@ -223,8 +223,8 @@ public abstract class CodecReader extends LeafReader {
}
@Override
public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
@ -232,7 +232,7 @@ public abstract class CodecReader extends LeafReader {
return null;
}
return getVectorReader().search(field, target, k, acceptDocs);
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
}
@Override

View File

@ -53,8 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -352,9 +352,9 @@ public abstract class FilterLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs);
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override

View File

@ -17,7 +17,9 @@
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
/**
@ -207,20 +209,31 @@ public abstract class LeafReader extends IndexReader {
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's search strategy. If the search strategy is
* reversed, lower values indicate nearer vectors, otherwise higher scores indicate nearer
* vectors. Unlike relevance scores, vector scores may be negative.
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
* true k closest neighbors. For large values of k (for example when k is close to the total
* number of documents), the search may also retrieve fewer than k documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
* TotalHits} contains the number of documents visited during the search. If the search stopped
* early because it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException;
public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Get the {@link FieldInfos} describing all fields in this reader.

View File

@ -393,11 +393,14 @@ public class ParallelLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs)
public TopDocs searchNearestVectors(
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs);
return reader == null
? null
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
}
@Override

View File

@ -167,9 +167,9 @@ public final class SlowCodecReaderWrapper {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return reader.searchNearestVectors(field, target, k, acceptDocs);
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override

View File

@ -384,7 +384,8 @@ public final class SortingCodecReader extends FilterCodecReader {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}

View File

@ -30,7 +30,7 @@ import org.apache.lucene.util.BytesRef;
public abstract class VectorValues extends DocIdSetIterator {
/** The maximum length of a vector */
public static int MAX_DIMENSIONS = 1024;
public static final int MAX_DIMENSIONS = 1024;
/** Sole constructor */
protected VectorValues() {}

View File

@ -135,7 +135,8 @@ class VectorValuesWriter {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
/**
* A {@link Query} that matches documents that contain a {@link
* org.apache.lucene.document.KnnVectorField}.
*/
public class KnnVectorFieldExistsQuery extends Query {
private final String field;
/** Create a query that will match documents which have a value for the given {@code field}. */
public KnnVectorFieldExistsQuery(String field) {
this.field = Objects.requireNonNull(field);
}
public String getField() {
return field;
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && field.equals(((KnnVectorFieldExistsQuery) other).field);
}
@Override
public int hashCode() {
return 31 * classHash() + field.hashCode();
}
@Override
public String toString(String field) {
return "KnnVectorFieldExistsQuery [field=" + this.field + "]";
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
return new ConstantScoreWeight(this, boost) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
DocIdSetIterator iterator = context.reader().getVectorValues(field);
if (iterator == null) {
return null;
}
return new ConstantScoreScorer(this, score(), scoreMode, iterator);
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
}

View File

@ -24,19 +24,36 @@ import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/** Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. */
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
*
* <ul>
* <li>If the filter cost is less than k, just execute an exact search
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
public class KnnVectorQuery extends Query {
private static final TopDocs NO_RESULTS =
new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final String field;
private final float[] target;
private final int k;
private final Query filter;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@ -48,19 +65,53 @@ public class KnnVectorQuery extends Query {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnVectorQuery(String field, float[] target, int k) {
this(field, target, k, null);
}
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field. <code>target</code> vector.
*
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
this.field = field;
this.target = target;
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
this.filter = filter;
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
BitSetCollector filterCollector = null;
if (filter != null) {
filterCollector = new BitSetCollector(reader.leaves().size());
IndexSearcher indexSearcher = new IndexSearcher(reader);
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new KnnVectorFieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
indexSearcher.search(booleanQuery, filterCollector);
}
for (LeafReaderContext ctx : reader.leaves()) {
perLeafResults[ctx.ord] = searchLeaf(ctx, k);
TopDocs results = searchLeaf(ctx, filterCollector);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
perLeafResults[ctx.ord] = results;
}
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
@ -70,18 +121,126 @@ public class KnnVectorQuery extends Query {
return createRewrittenQuery(reader, topK);
}
private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs);
if (results == null) {
return NO_RESULTS;
}
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
throws IOException {
if (filterCollector == null) {
Bits acceptDocs = ctx.reader().getLiveDocs();
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
} else {
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
if (filterIterator == null || filterIterator.cost() == 0) {
return NO_RESULTS;
}
if (filterIterator.cost() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
}
// Perform the approximate kNN search
Bits acceptDocs =
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
int visitedLimit = (int) filterIterator.cost();
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
}
}
return results;
}
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
return results != null ? results : NO_RESULTS;
}
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
VectorValues vectorValues = context.reader().getVectorValues(field);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
private static class BitSetCollector extends SimpleCollector {
private final BitSet[] bitSets;
private final int[] cost;
private int ord;
private BitSetCollector(int numLeaves) {
this.bitSets = new BitSet[numLeaves];
this.cost = new int[bitSets.length];
}
/**
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
* null.
*/
public BitSetIterator getIterator(int contextOrd) {
if (bitSets[contextOrd] == null) {
return null;
}
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
}
@Override
public void collect(int doc) throws IOException {
bitSets[ord].set(doc);
cost[ord]++;
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
ord = context.ord;
}
@Override
public org.apache.lucene.search.ScoreMode scoreMode() {
return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
}
}
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {

View File

@ -151,13 +151,12 @@ public final class HnswGraphBuilder {
// for levels > nodeLevel search with topk = 1
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw, null);
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
eps = new int[] {candidates.pop()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates =
graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw, null);
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
eps = candidates.nodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);

View File

@ -65,6 +65,7 @@ public final class HnswGraphSearcher {
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return a priority queue holding the closest neighbors found
*/
public static NeighborQueue search(
@ -73,7 +74,8 @@ public final class HnswGraphSearcher {
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds)
Bits acceptOrds,
int visitedLimit)
throws IOException {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
@ -82,16 +84,25 @@ public final class HnswGraphSearcher {
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null);
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
eps[0] = results.pop();
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
}
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds);
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
/**
* Searches for the nearest neighbors of a query vector in a given level
* Searches for the nearest neighbors of a query vector in a given level.
*
* <p>If the search stops early because it reaches the visited nodes limit, then the results will
* be marked incomplete through {@link NeighborQueue#incomplete()}.
*
* @param query search query vector
* @param topK the number of nearest to query results to return
@ -99,23 +110,34 @@ public final class HnswGraphSearcher {
* @param eps the entry points for search at this level expressed as level 0th ordinals
* @param vectors vector values
* @param graph the graph values
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @return a priority queue holding the closest neighbors found
*/
NeighborQueue searchLevel(
float[] query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues vectors,
HnswGraph graph)
throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
}
private NeighborQueue searchLevel(
float[] query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues vectors,
HnswGraph graph,
Bits acceptOrds)
Bits acceptOrds,
int visitedLimit)
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
clearScratchState();
int numVisited = 0;
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
float score = similarityFunction.compare(query, vectors.vectorValue(ep));
@ -123,6 +145,7 @@ public final class HnswGraphSearcher {
if (acceptOrds == null || acceptOrds.get(ep)) {
results.add(ep, score);
}
numVisited++;
}
}
@ -138,6 +161,12 @@ public final class HnswGraphSearcher {
if (bound.check(topCandidateScore)) {
break;
}
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
int friendOrd;
@ -148,6 +177,7 @@ public final class HnswGraphSearcher {
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
numVisited++;
if (bound.check(score) == false) {
candidates.add(friendOrd, score);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
@ -161,7 +191,7 @@ public final class HnswGraphSearcher {
while (results.size() > topK) {
results.pop();
}
results.setVisitedCount(visited.approximateCardinality());
results.setVisitedCount(numVisited);
return results;
}

View File

@ -53,6 +53,8 @@ public class NeighborQueue {
// Used to track the number of neighbors visited during a single graph traversal
private int visitedCount;
// Whether the search stopped early because it reached the visited nodes limit
private boolean incomplete;
public NeighborQueue(int initialSize, boolean reversed) {
this.heap = new LongHeap(initialSize);
@ -128,6 +130,14 @@ public class NeighborQueue {
this.visitedCount = visitedCount;
}
public boolean incomplete() {
return incomplete;
}
public void markIncomplete() {
this.incomplete = true;
}
@Override
public String toString() {
return "Neighbors[" + heap.size() + "]";

View File

@ -104,11 +104,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
try (IndexReader ireader = DirectoryReader.open(directory)) {
LeafReader reader = ireader.leaves().get(0).reader();
TopDocs hits1 =
reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
reader.searchNearestVectors(
"field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(1, hits1.scoreDocs.length);
TopDocs hits2 =
reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
reader.searchNearestVectors(
"field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(1, hits2.scoreDocs.length);
}
}

View File

@ -397,7 +397,9 @@ public class TestKnnGraph extends LuceneTestCase {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
Bits liveDocs = ctx.reader().getLiveDocs();
results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs);
results[ctx.ord] =
ctx.reader()
.searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, Integer.MAX_VALUE);
if (ctx.docBase > 0) {
for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
doc.doc += ctx.docBase;

View File

@ -112,7 +112,8 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -0,0 +1,140 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
public class TestKnnVectorFieldExistsQuery extends LuceneTestCase {
public void testRandom() throws IOException {
int iters = atLeast(10);
for (int iter = 0; iter < iters; ++iter) {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
int numDocs = atLeast(100);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
boolean hasValue = random().nextBoolean();
if (hasValue) {
doc.add(new KnnVectorField("vector", randomVector(5)));
doc.add(new StringField("has_value", "yes", Store.NO));
}
doc.add(new StringField("field", "value", Store.NO));
iw.addDocument(doc);
}
if (random().nextBoolean()) {
iw.deleteDocuments(new TermQuery(new Term("f", "no")));
}
iw.commit();
try (IndexReader reader = iw.getReader()) {
IndexSearcher searcher = newSearcher(reader);
assertSameMatches(
searcher,
new TermQuery(new Term("has_value", "yes")),
new KnnVectorFieldExistsQuery("vector"),
false);
float boost = random().nextFloat() * 10;
assertSameMatches(
searcher,
new BoostQuery(
new ConstantScoreQuery(new TermQuery(new Term("has_value", "yes"))), boost),
new BoostQuery(new KnnVectorFieldExistsQuery("vector"), boost),
true);
}
}
}
}
public void testMissingField() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
iw.addDocument(new Document());
iw.commit();
try (IndexReader reader = iw.getReader()) {
IndexSearcher searcher = newSearcher(reader);
assertEquals(0, searcher.count(new KnnVectorFieldExistsQuery("f")));
}
}
}
public void testAllDocsHaveField() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
Document doc = new Document();
doc.add(new KnnVectorField("vector", randomVector(3)));
iw.addDocument(doc);
iw.commit();
try (IndexReader reader = iw.getReader()) {
IndexSearcher searcher = newSearcher(reader);
assertEquals(1, searcher.count(new KnnVectorFieldExistsQuery("vector")));
}
}
}
public void testFieldExistsButNoDocsHaveField() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
// 1st segment has the field, but 2nd one does not
Document doc = new Document();
doc.add(new KnnVectorField("vector", randomVector(3)));
iw.addDocument(doc);
iw.commit();
iw.addDocument(new Document());
iw.commit();
try (IndexReader reader = iw.getReader()) {
IndexSearcher searcher = newSearcher(reader);
assertEquals(1, searcher.count(new KnnVectorFieldExistsQuery("vector")));
}
}
}
private float[] randomVector(int dim) {
float[] v = new float[dim];
for (int i = 0; i < dim; i++) {
v[i] = random().nextFloat();
}
VectorUtil.l2normalize(v);
return v;
}
private void assertSameMatches(IndexSearcher searcher, Query q1, Query q2, boolean scores)
throws IOException {
final int maxDoc = searcher.getIndexReader().maxDoc();
final TopDocs td1 = searcher.search(q1, maxDoc, scores ? Sort.RELEVANCE : Sort.INDEXORDER);
final TopDocs td2 = searcher.search(q2, maxDoc, scores ? Sort.RELEVANCE : Sort.INDEXORDER);
assertEquals(td1.totalHits.value, td2.totalHits.value);
for (int i = 0; i < td1.scoreDocs.length; ++i) {
assertEquals(td1.scoreDocs[i].doc, td2.scoreDocs[i].doc);
if (scores) {
assertEquals(td1.scoreDocs[i].score, td2.scoreDocs[i].score, 10e-7);
}
}
}
}

View File

@ -27,7 +27,9 @@ import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FilterDirectoryReader;
@ -36,6 +38,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
@ -91,7 +94,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
assertMatches(searcher, kvq, reader.numDocs());
assertMatches(searcher, kvq, 3);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(2, topDocs.scoreDocs[0].doc);
assertEquals(0, topDocs.scoreDocs[1].doc);
@ -99,6 +102,33 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
/** Tests that a KnnVectorQuery applies the filter query */
public void testSimpleFilter() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("id", "id2"));
Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(1, topDocs.totalHits.value);
assertEquals(2, topDocs.scoreDocs[0].doc);
}
}
public void testFilterWithNoVectorMatches() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(0, topDocs.totalHits.value);
}
}
/** testDimensionMismatch */
public void testDimensionMismatch() throws IOException {
try (Directory indexStore =
@ -455,6 +485,78 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
/** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
public void testRandomWithFilter() throws IOException {
int numDocs = 200;
int dimension = atLeast(5);
int numIters = atLeast(10);
try (Directory d = newDirectory()) {
RandomIndexWriter w = new RandomIndexWriter(random(), d);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
doc.add(new KnnVectorField("field", randomVector(dimension)));
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (DirectoryReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int lower = random().nextInt(50);
// Test a filter with cost less than k and check we use exact search
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
TopDocs results =
searcher.search(
new KnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
new ThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
numDocs));
// Test a restrictive filter and check we use exact search
Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
results =
searcher.search(
new KnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
numDocs));
// Test an unrestrictive filter and check we use approximate search
Query filter3 = IntPoint.newRangeQuery("tag", lower, lower + 150);
results =
searcher.search(
new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
numDocs,
new Sort(new SortField("tag", SortField.Type.INT)));
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
for (ScoreDoc scoreDoc : results.scoreDocs) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
assertEquals(1, fieldDoc.fields.length);
int tag = (int) fieldDoc.fields[0];
assertTrue(lower <= tag && tag <= lower + 150);
}
}
}
}
}
public void testDeletes() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
@ -550,6 +652,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
@ -559,6 +662,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
doc.add(new StringField("id", "id" + i, Field.Store.NO));
writer.addDocument(doc);
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {
Document doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
writer.addDocument(doc);
}
writer.close();
return indexStore;
}
@ -569,6 +679,23 @@ public class TestKnnVectorQuery extends LuceneTestCase {
assertEquals(expectedMatches, result.length);
}
/**
* A version of {@link KnnVectorQuery} that throws an error when an exact search is run. This
* allows us to check what search strategy is being used.
*/
private static class ThrowingKnnVectorQuery extends KnnVectorQuery {
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
super(field, target, k, filter);
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
throw new UnsupportedOperationException("exact search is not supported");
}
}
private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {
private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {

View File

@ -435,7 +435,8 @@ public class KnnGraphTester {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
Bits liveDocs = ctx.reader().getLiveDocs();
results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs);
results[ctx.ord] =
ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE);
int docBase = ctx.docBase;
for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
scoreDoc.doc += docBase;

View File

@ -45,6 +45,7 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
@ -167,7 +168,8 @@ public class TestHnswGraph extends LuceneTestCase {
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null);
null,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
@ -206,7 +208,8 @@ public class TestHnswGraph extends LuceneTestCase {
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds);
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
@ -219,6 +222,38 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue("sum(result docs)=" + sum, sum < 75);
}
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
int maxConn = 16;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size);
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
acceptOrds.set(i);
}
// Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality();
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
numAccepted,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertEquals(numAccepted, nodes.length);
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
}
}
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
@ -239,7 +274,8 @@ public class TestHnswGraph extends LuceneTestCase {
vectors.randomAccess(),
VectorSimilarityFunction.EUCLIDEAN,
hnsw,
acceptOrds);
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
@ -252,6 +288,31 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue("sum(result docs)=" + sum, sum < 5100);
}
public void testVisitedLimit() throws IOException {
int nDoc = 500;
int maxConn = 16;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
topK,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);
assertTrue(nn.incomplete());
// The visited count shouldn't be much over the limit
assertTrue(nn.visitedCount() < visitedLimit + 3);
}
public void testBoundsCheckerMax() {
BoundsChecker max = BoundsChecker.create(false);
float f = random().nextFloat() - 0.5f;
@ -382,7 +443,8 @@ public class TestHnswGraph extends LuceneTestCase {
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
HnswGraphSearcher.search(query, 100, vectors, similarityFunction, hnsw, acceptOrds);
HnswGraphSearcher.search(
query, 100, vectors, similarityFunction, hnsw, acceptOrds, Integer.MAX_VALUE);
while (actual.size() > topK) {
actual.pop();
}

View File

@ -161,7 +161,8 @@ public class TermVectorLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -1351,7 +1351,8 @@ public class MemoryIndex {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -114,10 +114,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
TopDocs hits = delegate.search(field, target, k, acceptDocs);
TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;

View File

@ -563,7 +563,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
// assert that knn search doesn't fail on a field with all deleted docs
TopDocs results =
leafReader.searchNearestVectors("v", randomVector(3), 1, leafReader.getLiveDocs());
leafReader.searchNearestVectors(
"v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(0, results.scoreDocs.length);
}
}
@ -887,7 +888,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
k = numLiveDocsWithVectors;
}
TopDocs results =
ctx.reader().searchNearestVectors(fieldName, randomVector(dimension), k, liveDocs);
ctx.reader()
.searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE);
assertEquals(Math.min(k, size), results.scoreDocs.length);
for (int i = 0; i < k - 1; i++) {
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);

View File

@ -223,9 +223,9 @@ class MergeReaderWrapper extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs);
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override

View File

@ -228,7 +228,8 @@ public class QueryUtils {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}