mirror of https://github.com/apache/lucene.git
LUCENE-10382: Support filtering in KnnVectorQuery (#656)
This PR adds support for a query filter in KnnVectorQuery. First, we gather the query results for each leaf as a bit set. Then the HNSW search skips over the non-matching documents (using the same approach as for live docs). To prevent HNSW search from visiting too many documents when the filter is very selective, we short-circuit if HNSW has already visited more than the number of documents that match the filter, and execute an exact search instead. This bounds the number of visited documents at roughly 2x the cost of just running the exact filter, while in most cases HNSW completes successfully and does a lot better. Co-authored-by: Joel Bernstein <jbernste@apache.org>
This commit is contained in:
parent
c132bbf677
commit
8ca372573d
|
@ -158,6 +158,9 @@ New Features
|
|||
|
||||
* LUCENE-10415: FunctionScoreQuery and IndexOrDocValuesQuery delegate Weight#count. (Ignacio Vera)
|
||||
|
||||
* LUCENE-10382: Add support for filtering in KnnVectorQuery. This allows for finding the
|
||||
nearest k documents that also match a query. (Julie Tibshirani, Joel Bernstein)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -229,7 +229,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0) {
|
||||
|
|
|
@ -144,7 +144,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
VectorValues values = getVectorValues(field);
|
||||
if (target.length != values.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
|
|
|
@ -100,7 +100,8 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,9 @@ import java.io.Closeable;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -58,6 +60,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||
* true k closest neighbors. For large values of k (for example when k is close to the total
|
||||
* number of documents), the search may also retrieve fewer than k documents.
|
||||
*
|
||||
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
|
||||
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
|
||||
* contains the number of documents visited during the search. If the search stopped early because
|
||||
* it hit {@code visitedLimit}, it is indicated through the relation {@code
|
||||
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
|
||||
*
|
||||
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
|
||||
* FieldInfo}. The return value is never {@code null}.
|
||||
*
|
||||
|
@ -66,10 +74,11 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||
* @param k the number of docs to return
|
||||
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
|
||||
* if they are all allowed to match.
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
|
||||
*/
|
||||
public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException;
|
||||
public abstract TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns an instance optimized for merging. This instance may only be consumed in the thread
|
||||
|
|
|
@ -86,7 +86,8 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
});
|
||||
|
|
|
@ -219,7 +219,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0) {
|
||||
|
@ -228,8 +229,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
// bound k by total number of vectors to prevent oversizing data structures
|
||||
k = Math.min(k, fieldEntry.size());
|
||||
|
||||
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
|
||||
|
||||
NeighborQueue results =
|
||||
HnswGraphSearcher.search(
|
||||
target,
|
||||
|
@ -237,7 +238,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
vectorValues,
|
||||
fieldEntry.similarityFunction,
|
||||
getGraph(fieldEntry),
|
||||
getAcceptOrds(acceptDocs, fieldEntry));
|
||||
getAcceptOrds(acceptDocs, fieldEntry),
|
||||
visitedLimit);
|
||||
|
||||
int i = 0;
|
||||
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
|
||||
|
@ -247,11 +249,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
results.pop();
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), score);
|
||||
}
|
||||
// always return >= the case where we can assert == is only when there are fewer than topK
|
||||
// vectors in the index
|
||||
return new TopDocs(
|
||||
new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
|
||||
scoreDocs);
|
||||
|
||||
TotalHits.Relation relation =
|
||||
results.incomplete()
|
||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|
||||
: TotalHits.Relation.EQUAL_TO;
|
||||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
||||
|
|
|
@ -263,12 +263,13 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||
} else {
|
||||
return knnVectorsReader.search(field, target, k, acceptDocs);
|
||||
return knnVectorsReader.search(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -223,8 +223,8 @@ public abstract class CodecReader extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
public final TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
ensureOpen();
|
||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
|
@ -232,7 +232,7 @@ public abstract class CodecReader extends LeafReader {
|
|||
return null;
|
||||
}
|
||||
|
||||
return getVectorReader().search(field, target, k, acceptDocs);
|
||||
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -53,8 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -352,9 +352,9 @@ public abstract class FilterLeafReader extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
return in.searchNearestVectors(field, target, k, acceptDocs);
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
|
@ -207,20 +209,31 @@ public abstract class LeafReader extends IndexReader {
|
|||
|
||||
/**
|
||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||
* this field, to the given vector, by the field's search strategy. If the search strategy is
|
||||
* reversed, lower values indicate nearer vectors, otherwise higher scores indicate nearer
|
||||
* vectors. Unlike relevance scores, vector scores may be negative.
|
||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||
* is derived from the vector similarity in a way that ensures scores are positive and that a
|
||||
* larger score corresponds to a higher ranking.
|
||||
*
|
||||
* <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
|
||||
* true k closest neighbors. For large values of k (for example when k is close to the total
|
||||
* number of documents), the search may also retrieve fewer than k documents.
|
||||
*
|
||||
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
|
||||
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
|
||||
* TotalHits} contains the number of documents visited during the search. If the search stopped
|
||||
* early because it hit {@code visitedLimit}, it is indicated through the relation {@code
|
||||
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
|
||||
*
|
||||
* @param field the vector field to search
|
||||
* @param target the vector-valued query
|
||||
* @param k the number of docs to return
|
||||
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
|
||||
* if they are all allowed to match.
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException;
|
||||
public abstract TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
|
||||
/**
|
||||
* Get the {@link FieldInfos} describing all fields in this reader.
|
||||
|
|
|
@ -393,11 +393,14 @@ public class ParallelLeafReader extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs)
|
||||
public TopDocs searchNearestVectors(
|
||||
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
ensureOpen();
|
||||
LeafReader reader = fieldToReader.get(fieldName);
|
||||
return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs);
|
||||
return reader == null
|
||||
? null
|
||||
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -167,9 +167,9 @@ public final class SlowCodecReaderWrapper {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
return reader.searchNearestVectors(field, target, k, acceptDocs);
|
||||
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -384,7 +384,8 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.lucene.util.BytesRef;
|
|||
public abstract class VectorValues extends DocIdSetIterator {
|
||||
|
||||
/** The maximum length of a vector */
|
||||
public static int MAX_DIMENSIONS = 1024;
|
||||
public static final int MAX_DIMENSIONS = 1024;
|
||||
|
||||
/** Sole constructor */
|
||||
protected VectorValues() {}
|
||||
|
|
|
@ -135,7 +135,8 @@ class VectorValuesWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
|
||||
/**
|
||||
* A {@link Query} that matches documents that contain a {@link
|
||||
* org.apache.lucene.document.KnnVectorField}.
|
||||
*/
|
||||
public class KnnVectorFieldExistsQuery extends Query {
|
||||
|
||||
private final String field;
|
||||
|
||||
/** Create a query that will match documents which have a value for the given {@code field}. */
|
||||
public KnnVectorFieldExistsQuery(String field) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
}
|
||||
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
return sameClassAs(other) && field.equals(((KnnVectorFieldExistsQuery) other).field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + field.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "KnnVectorFieldExistsQuery [field=" + this.field + "]";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
if (visitor.acceptField(field)) {
|
||||
visitor.visitLeaf(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
|
||||
return new ConstantScoreWeight(this, boost) {
|
||||
@Override
|
||||
public Scorer scorer(LeafReaderContext context) throws IOException {
|
||||
DocIdSetIterator iterator = context.reader().getVectorValues(field);
|
||||
if (iterator == null) {
|
||||
return null;
|
||||
}
|
||||
return new ConstantScoreScorer(this, score(), scoreMode, iterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -24,19 +24,36 @@ import java.util.Comparator;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
||||
/** Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. */
|
||||
/**
|
||||
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
|
||||
*
|
||||
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
|
||||
* executes the filter for each leaf, then chooses a strategy dynamically:
|
||||
*
|
||||
* <ul>
|
||||
* <li>If the filter cost is less than k, just execute an exact search
|
||||
* <li>Otherwise run a kNN search subject to the filter
|
||||
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
|
||||
* </ul>
|
||||
*/
|
||||
public class KnnVectorQuery extends Query {
|
||||
|
||||
private static final TopDocs NO_RESULTS =
|
||||
new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
|
||||
|
||||
private final String field;
|
||||
private final float[] target;
|
||||
private final int k;
|
||||
private final Query filter;
|
||||
|
||||
/**
|
||||
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
|
||||
|
@ -48,19 +65,53 @@ public class KnnVectorQuery extends Query {
|
|||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||
*/
|
||||
public KnnVectorQuery(String field, float[] target, int k) {
|
||||
this(field, target, k, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
|
||||
* given field. <code>target</code> vector.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnVectorField}.
|
||||
* @param target the target of the search
|
||||
* @param k the number of documents to find
|
||||
* @param filter a filter applied before the vector search
|
||||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||
*/
|
||||
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
|
||||
this.field = field;
|
||||
this.target = target;
|
||||
this.k = k;
|
||||
if (k < 1) {
|
||||
throw new IllegalArgumentException("k must be at least 1, got: " + k);
|
||||
}
|
||||
this.filter = filter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query rewrite(IndexReader reader) throws IOException {
|
||||
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
|
||||
|
||||
BitSetCollector filterCollector = null;
|
||||
if (filter != null) {
|
||||
filterCollector = new BitSetCollector(reader.leaves().size());
|
||||
IndexSearcher indexSearcher = new IndexSearcher(reader);
|
||||
BooleanQuery booleanQuery =
|
||||
new BooleanQuery.Builder()
|
||||
.add(filter, BooleanClause.Occur.FILTER)
|
||||
.add(new KnnVectorFieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
indexSearcher.search(booleanQuery, filterCollector);
|
||||
}
|
||||
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
perLeafResults[ctx.ord] = searchLeaf(ctx, k);
|
||||
TopDocs results = searchLeaf(ctx, filterCollector);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
}
|
||||
}
|
||||
perLeafResults[ctx.ord] = results;
|
||||
}
|
||||
// Merge sort the results
|
||||
TopDocs topK = TopDocs.merge(k, perLeafResults);
|
||||
|
@ -70,18 +121,126 @@ public class KnnVectorQuery extends Query {
|
|||
return createRewrittenQuery(reader, topK);
|
||||
}
|
||||
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs);
|
||||
if (results == null) {
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
|
||||
throws IOException {
|
||||
|
||||
if (filterCollector == null) {
|
||||
Bits acceptDocs = ctx.reader().getLiveDocs();
|
||||
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
|
||||
} else {
|
||||
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
|
||||
if (filterIterator == null || filterIterator.cost() == 0) {
|
||||
return NO_RESULTS;
|
||||
}
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
}
|
||||
|
||||
if (filterIterator.cost() <= k) {
|
||||
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
||||
// must always visit at least k documents
|
||||
return exactSearch(ctx, filterIterator);
|
||||
}
|
||||
|
||||
// Perform the approximate kNN search
|
||||
Bits acceptDocs =
|
||||
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
|
||||
int visitedLimit = (int) filterIterator.cost();
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||
return results;
|
||||
} else {
|
||||
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
||||
return exactSearch(ctx, filterIterator);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
TopDocs results =
|
||||
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
return results != null ? results : NO_RESULTS;
|
||||
}
|
||||
|
||||
// We allow this to be overridden so that tests can check what search strategy is used
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
throws IOException {
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// The field does not exist or does not index vectors
|
||||
return NO_RESULTS;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
|
||||
VectorValues vectorValues = context.reader().getVectorValues(field);
|
||||
|
||||
HitQueue queue = new HitQueue(k, true);
|
||||
ScoreDoc topDoc = queue.top();
|
||||
int doc;
|
||||
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
int vectorDoc = vectorValues.advance(doc);
|
||||
assert vectorDoc == doc;
|
||||
float[] vector = vectorValues.vectorValue();
|
||||
|
||||
float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
|
||||
if (score >= topDoc.score) {
|
||||
topDoc.score = score;
|
||||
topDoc.doc = doc;
|
||||
topDoc = queue.updateTop();
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any remaining sentinel values
|
||||
while (queue.size() > 0 && queue.top().score < 0) {
|
||||
queue.pop();
|
||||
}
|
||||
|
||||
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
|
||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
private static class BitSetCollector extends SimpleCollector {
|
||||
|
||||
private final BitSet[] bitSets;
|
||||
private final int[] cost;
|
||||
private int ord;
|
||||
|
||||
private BitSetCollector(int numLeaves) {
|
||||
this.bitSets = new BitSet[numLeaves];
|
||||
this.cost = new int[bitSets.length];
|
||||
}
|
||||
|
||||
/**
|
||||
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
|
||||
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
|
||||
* null.
|
||||
*/
|
||||
public BitSetIterator getIterator(int contextOrd) {
|
||||
if (bitSets[contextOrd] == null) {
|
||||
return null;
|
||||
}
|
||||
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void collect(int doc) throws IOException {
|
||||
bitSets[ord].set(doc);
|
||||
cost[ord]++;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doSetNextReader(LeafReaderContext context) throws IOException {
|
||||
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
|
||||
ord = context.ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.apache.lucene.search.ScoreMode scoreMode() {
|
||||
return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
|
||||
}
|
||||
}
|
||||
|
||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||
|
|
|
@ -151,13 +151,12 @@ public final class HnswGraphBuilder {
|
|||
|
||||
// for levels > nodeLevel search with topk = 1
|
||||
for (int level = curMaxLevel; level > nodeLevel; level--) {
|
||||
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw, null);
|
||||
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
|
||||
eps = new int[] {candidates.pop()};
|
||||
}
|
||||
// for levels <= nodeLevel search with topk = beamWidth, and add connections
|
||||
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
|
||||
candidates =
|
||||
graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw, null);
|
||||
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
|
||||
eps = candidates.nodes();
|
||||
hnsw.addNode(level, node);
|
||||
addDiverseNeighbors(level, node, candidates);
|
||||
|
|
|
@ -65,6 +65,7 @@ public final class HnswGraphSearcher {
|
|||
* graph.
|
||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||
* {@code null} if they are all allowed to match.
|
||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||
* @return a priority queue holding the closest neighbors found
|
||||
*/
|
||||
public static NeighborQueue search(
|
||||
|
@ -73,7 +74,8 @@ public final class HnswGraphSearcher {
|
|||
RandomAccessVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
HnswGraph graph,
|
||||
Bits acceptOrds)
|
||||
Bits acceptOrds,
|
||||
int visitedLimit)
|
||||
throws IOException {
|
||||
HnswGraphSearcher graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
|
@ -82,16 +84,25 @@ public final class HnswGraphSearcher {
|
|||
new SparseFixedBitSet(vectors.size()));
|
||||
NeighborQueue results;
|
||||
int[] eps = new int[] {graph.entryNode()};
|
||||
int numVisited = 0;
|
||||
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null);
|
||||
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
|
||||
eps[0] = results.pop();
|
||||
|
||||
numVisited += results.visitedCount();
|
||||
visitedLimit -= results.visitedCount();
|
||||
}
|
||||
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds);
|
||||
results =
|
||||
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
|
||||
results.setVisitedCount(results.visitedCount() + numVisited);
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Searches for the nearest neighbors of a query vector in a given level
|
||||
* Searches for the nearest neighbors of a query vector in a given level.
|
||||
*
|
||||
* <p>If the search stops early because it reaches the visited nodes limit, then the results will
|
||||
* be marked incomplete through {@link NeighborQueue#incomplete()}.
|
||||
*
|
||||
* @param query search query vector
|
||||
* @param topK the number of nearest to query results to return
|
||||
|
@ -99,23 +110,34 @@ public final class HnswGraphSearcher {
|
|||
* @param eps the entry points for search at this level expressed as level 0th ordinals
|
||||
* @param vectors vector values
|
||||
* @param graph the graph values
|
||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||
* {@code null} if they are all allowed to match.
|
||||
* @return a priority queue holding the closest neighbors found
|
||||
*/
|
||||
NeighborQueue searchLevel(
|
||||
float[] query,
|
||||
int topK,
|
||||
int level,
|
||||
final int[] eps,
|
||||
RandomAccessVectorValues vectors,
|
||||
HnswGraph graph)
|
||||
throws IOException {
|
||||
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
private NeighborQueue searchLevel(
|
||||
float[] query,
|
||||
int topK,
|
||||
int level,
|
||||
final int[] eps,
|
||||
RandomAccessVectorValues vectors,
|
||||
HnswGraph graph,
|
||||
Bits acceptOrds)
|
||||
Bits acceptOrds,
|
||||
int visitedLimit)
|
||||
throws IOException {
|
||||
int size = graph.size();
|
||||
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
clearScratchState();
|
||||
|
||||
int numVisited = 0;
|
||||
for (int ep : eps) {
|
||||
if (visited.getAndSet(ep) == false) {
|
||||
float score = similarityFunction.compare(query, vectors.vectorValue(ep));
|
||||
|
@ -123,6 +145,7 @@ public final class HnswGraphSearcher {
|
|||
if (acceptOrds == null || acceptOrds.get(ep)) {
|
||||
results.add(ep, score);
|
||||
}
|
||||
numVisited++;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,6 +161,12 @@ public final class HnswGraphSearcher {
|
|||
if (bound.check(topCandidateScore)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (numVisited >= visitedLimit) {
|
||||
results.markIncomplete();
|
||||
break;
|
||||
}
|
||||
|
||||
int topCandidateNode = candidates.pop();
|
||||
graph.seek(level, topCandidateNode);
|
||||
int friendOrd;
|
||||
|
@ -148,6 +177,7 @@ public final class HnswGraphSearcher {
|
|||
}
|
||||
|
||||
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
||||
numVisited++;
|
||||
if (bound.check(score) == false) {
|
||||
candidates.add(friendOrd, score);
|
||||
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
|
||||
|
@ -161,7 +191,7 @@ public final class HnswGraphSearcher {
|
|||
while (results.size() > topK) {
|
||||
results.pop();
|
||||
}
|
||||
results.setVisitedCount(visited.approximateCardinality());
|
||||
results.setVisitedCount(numVisited);
|
||||
return results;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,8 @@ public class NeighborQueue {
|
|||
|
||||
// Used to track the number of neighbors visited during a single graph traversal
|
||||
private int visitedCount;
|
||||
// Whether the search stopped early because it reached the visited nodes limit
|
||||
private boolean incomplete;
|
||||
|
||||
public NeighborQueue(int initialSize, boolean reversed) {
|
||||
this.heap = new LongHeap(initialSize);
|
||||
|
@ -128,6 +130,14 @@ public class NeighborQueue {
|
|||
this.visitedCount = visitedCount;
|
||||
}
|
||||
|
||||
public boolean incomplete() {
|
||||
return incomplete;
|
||||
}
|
||||
|
||||
public void markIncomplete() {
|
||||
this.incomplete = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Neighbors[" + heap.size() + "]";
|
||||
|
|
|
@ -104,11 +104,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
try (IndexReader ireader = DirectoryReader.open(directory)) {
|
||||
LeafReader reader = ireader.leaves().get(0).reader();
|
||||
TopDocs hits1 =
|
||||
reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
|
||||
reader.searchNearestVectors(
|
||||
"field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE);
|
||||
assertEquals(1, hits1.scoreDocs.length);
|
||||
|
||||
TopDocs hits2 =
|
||||
reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
|
||||
reader.searchNearestVectors(
|
||||
"field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE);
|
||||
assertEquals(1, hits2.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -397,7 +397,9 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
TopDocs[] results = new TopDocs[reader.leaves().size()];
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs);
|
||||
results[ctx.ord] =
|
||||
ctx.reader()
|
||||
.searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, Integer.MAX_VALUE);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
|
||||
doc.doc += ctx.docBase;
|
||||
|
|
|
@ -112,7 +112,8 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field.Store;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
public class TestKnnVectorFieldExistsQuery extends LuceneTestCase {
|
||||
|
||||
public void testRandom() throws IOException {
|
||||
int iters = atLeast(10);
|
||||
for (int iter = 0; iter < iters; ++iter) {
|
||||
try (Directory dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
|
||||
int numDocs = atLeast(100);
|
||||
for (int i = 0; i < numDocs; ++i) {
|
||||
Document doc = new Document();
|
||||
boolean hasValue = random().nextBoolean();
|
||||
if (hasValue) {
|
||||
doc.add(new KnnVectorField("vector", randomVector(5)));
|
||||
doc.add(new StringField("has_value", "yes", Store.NO));
|
||||
}
|
||||
doc.add(new StringField("field", "value", Store.NO));
|
||||
iw.addDocument(doc);
|
||||
}
|
||||
if (random().nextBoolean()) {
|
||||
iw.deleteDocuments(new TermQuery(new Term("f", "no")));
|
||||
}
|
||||
iw.commit();
|
||||
|
||||
try (IndexReader reader = iw.getReader()) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
assertSameMatches(
|
||||
searcher,
|
||||
new TermQuery(new Term("has_value", "yes")),
|
||||
new KnnVectorFieldExistsQuery("vector"),
|
||||
false);
|
||||
|
||||
float boost = random().nextFloat() * 10;
|
||||
assertSameMatches(
|
||||
searcher,
|
||||
new BoostQuery(
|
||||
new ConstantScoreQuery(new TermQuery(new Term("has_value", "yes"))), boost),
|
||||
new BoostQuery(new KnnVectorFieldExistsQuery("vector"), boost),
|
||||
true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testMissingField() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
|
||||
iw.addDocument(new Document());
|
||||
iw.commit();
|
||||
try (IndexReader reader = iw.getReader()) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
assertEquals(0, searcher.count(new KnnVectorFieldExistsQuery("f")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testAllDocsHaveField() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("vector", randomVector(3)));
|
||||
iw.addDocument(doc);
|
||||
iw.commit();
|
||||
try (IndexReader reader = iw.getReader()) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
assertEquals(1, searcher.count(new KnnVectorFieldExistsQuery("vector")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testFieldExistsButNoDocsHaveField() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
|
||||
// 1st segment has the field, but 2nd one does not
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("vector", randomVector(3)));
|
||||
iw.addDocument(doc);
|
||||
iw.commit();
|
||||
iw.addDocument(new Document());
|
||||
iw.commit();
|
||||
try (IndexReader reader = iw.getReader()) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
assertEquals(1, searcher.count(new KnnVectorFieldExistsQuery("vector")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private float[] randomVector(int dim) {
|
||||
float[] v = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
v[i] = random().nextFloat();
|
||||
}
|
||||
VectorUtil.l2normalize(v);
|
||||
return v;
|
||||
}
|
||||
|
||||
private void assertSameMatches(IndexSearcher searcher, Query q1, Query q2, boolean scores)
|
||||
throws IOException {
|
||||
final int maxDoc = searcher.getIndexReader().maxDoc();
|
||||
final TopDocs td1 = searcher.search(q1, maxDoc, scores ? Sort.RELEVANCE : Sort.INDEXORDER);
|
||||
final TopDocs td2 = searcher.search(q2, maxDoc, scores ? Sort.RELEVANCE : Sort.INDEXORDER);
|
||||
assertEquals(td1.totalHits.value, td2.totalHits.value);
|
||||
for (int i = 0; i < td1.scoreDocs.length; ++i) {
|
||||
assertEquals(td1.scoreDocs[i].doc, td2.scoreDocs[i].doc);
|
||||
if (scores) {
|
||||
assertEquals(td1.scoreDocs[i].score, td2.scoreDocs[i].score, 10e-7);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -27,7 +27,9 @@ import java.util.HashSet;
|
|||
import java.util.Set;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.IntPoint;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FilterDirectoryReader;
|
||||
|
@ -36,6 +38,7 @@ import org.apache.lucene.index.IndexReader;
|
|||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
|
@ -91,7 +94,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
|
||||
assertMatches(searcher, kvq, reader.numDocs());
|
||||
assertMatches(searcher, kvq, 3);
|
||||
TopDocs topDocs = searcher.search(kvq, 3);
|
||||
assertEquals(2, topDocs.scoreDocs[0].doc);
|
||||
assertEquals(0, topDocs.scoreDocs[1].doc);
|
||||
|
@ -99,6 +102,33 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Tests that a KnnVectorQuery applies the filter query */
|
||||
public void testSimpleFilter() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
Query filter = new TermQuery(new Term("id", "id2"));
|
||||
Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
|
||||
TopDocs topDocs = searcher.search(kvq, 3);
|
||||
assertEquals(1, topDocs.totalHits.value);
|
||||
assertEquals(2, topDocs.scoreDocs[0].doc);
|
||||
}
|
||||
}
|
||||
|
||||
public void testFilterWithNoVectorMatches() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query filter = new TermQuery(new Term("other", "value"));
|
||||
Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
|
||||
TopDocs topDocs = searcher.search(kvq, 3);
|
||||
assertEquals(0, topDocs.totalHits.value);
|
||||
}
|
||||
}
|
||||
|
||||
/** testDimensionMismatch */
|
||||
public void testDimensionMismatch() throws IOException {
|
||||
try (Directory indexStore =
|
||||
|
@ -455,6 +485,78 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
|
||||
public void testRandomWithFilter() throws IOException {
|
||||
int numDocs = 200;
|
||||
int dimension = atLeast(5);
|
||||
int numIters = atLeast(10);
|
||||
try (Directory d = newDirectory()) {
|
||||
RandomIndexWriter w = new RandomIndexWriter(random(), d);
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("field", randomVector(dimension)));
|
||||
doc.add(new NumericDocValuesField("tag", i));
|
||||
doc.add(new IntPoint("tag", i));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
w.forceMerge(1);
|
||||
w.close();
|
||||
|
||||
try (DirectoryReader reader = DirectoryReader.open(d)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
int lower = random().nextInt(50);
|
||||
|
||||
// Test a filter with cost less than k and check we use exact search
|
||||
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
|
||||
TopDocs results =
|
||||
searcher.search(
|
||||
new KnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
|
||||
assertEquals(9, results.totalHits.value);
|
||||
assertEquals(results.totalHits.value, results.scoreDocs.length);
|
||||
expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
new ThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
|
||||
numDocs));
|
||||
|
||||
// Test a restrictive filter and check we use exact search
|
||||
Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
|
||||
results =
|
||||
searcher.search(
|
||||
new KnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
|
||||
assertEquals(5, results.totalHits.value);
|
||||
assertEquals(results.totalHits.value, results.scoreDocs.length);
|
||||
expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
|
||||
numDocs));
|
||||
|
||||
// Test an unrestrictive filter and check we use approximate search
|
||||
Query filter3 = IntPoint.newRangeQuery("tag", lower, lower + 150);
|
||||
results =
|
||||
searcher.search(
|
||||
new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
|
||||
numDocs,
|
||||
new Sort(new SortField("tag", SortField.Type.INT)));
|
||||
assertEquals(5, results.totalHits.value);
|
||||
assertEquals(results.totalHits.value, results.scoreDocs.length);
|
||||
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
|
||||
assertEquals(1, fieldDoc.fields.length);
|
||||
|
||||
int tag = (int) fieldDoc.fields[0];
|
||||
assertTrue(lower <= tag && tag <= lower + 150);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testDeletes() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
|
@ -550,6 +652,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||
|
@ -559,6 +662,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
doc.add(new StringField("id", "id" + i, Field.Store.NO));
|
||||
writer.addDocument(doc);
|
||||
}
|
||||
// Add some documents without a vector
|
||||
for (int i = 0; i < 5; i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(new StringField("other", "value", Field.Store.NO));
|
||||
writer.addDocument(doc);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
return indexStore;
|
||||
}
|
||||
|
@ -569,6 +679,23 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
assertEquals(expectedMatches, result.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* A version of {@link KnnVectorQuery} that throws an error when an exact search is run. This
|
||||
* allows us to check what search strategy is being used.
|
||||
*/
|
||||
private static class ThrowingKnnVectorQuery extends KnnVectorQuery {
|
||||
|
||||
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
|
||||
super(field, target, k, filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException("exact search is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {
|
||||
|
||||
private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {
|
||||
|
|
|
@ -435,7 +435,8 @@ public class KnnGraphTester {
|
|||
TopDocs[] results = new TopDocs[reader.leaves().size()];
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs);
|
||||
results[ctx.ord] =
|
||||
ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE);
|
||||
int docBase = ctx.docBase;
|
||||
for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
|
||||
scoreDoc.doc += docBase;
|
||||
|
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.index.VectorValues;
|
|||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
@ -167,7 +168,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
vectors.randomAccess(),
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
null);
|
||||
null,
|
||||
Integer.MAX_VALUE);
|
||||
|
||||
int[] nodes = nn.nodes();
|
||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
||||
|
@ -206,7 +208,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
vectors.randomAccess(),
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
acceptOrds);
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
int[] nodes = nn.nodes();
|
||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
||||
int sum = 0;
|
||||
|
@ -219,6 +222,38 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertTrue("sum(result docs)=" + sum, sum < 75);
|
||||
}
|
||||
|
||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
int maxConn = 16;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// Only mark a few vectors as accepted
|
||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
||||
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
|
||||
acceptOrds.set(i);
|
||||
}
|
||||
|
||||
// Check the search finds all accepted vectors
|
||||
int numAccepted = acceptOrds.cardinality();
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
new float[] {1, 0},
|
||||
numAccepted,
|
||||
vectors.randomAccess(),
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
int[] nodes = nn.nodes();
|
||||
assertEquals(numAccepted, nodes.length);
|
||||
for (int node : nodes) {
|
||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||
}
|
||||
}
|
||||
|
||||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||
int nDoc = 1000;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
|
@ -239,7 +274,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
vectors.randomAccess(),
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
hnsw,
|
||||
acceptOrds);
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
int[] nodes = nn.nodes();
|
||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
||||
int sum = 0;
|
||||
|
@ -252,6 +288,31 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertTrue("sum(result docs)=" + sum, sum < 5100);
|
||||
}
|
||||
|
||||
public void testVisitedLimit() throws IOException {
|
||||
int nDoc = 500;
|
||||
int maxConn = 16;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
|
||||
int topK = 50;
|
||||
int visitedLimit = topK + random().nextInt(5);
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
new float[] {1, 0},
|
||||
topK,
|
||||
vectors.randomAccess(),
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
createRandomAcceptOrds(0, vectors.size),
|
||||
visitedLimit);
|
||||
assertTrue(nn.incomplete());
|
||||
// The visited count shouldn't be much over the limit
|
||||
assertTrue(nn.visitedCount() < visitedLimit + 3);
|
||||
}
|
||||
|
||||
public void testBoundsCheckerMax() {
|
||||
BoundsChecker max = BoundsChecker.create(false);
|
||||
float f = random().nextFloat() - 0.5f;
|
||||
|
@ -382,7 +443,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
for (int i = 0; i < 100; i++) {
|
||||
float[] query = randomVector(random(), dim);
|
||||
NeighborQueue actual =
|
||||
HnswGraphSearcher.search(query, 100, vectors, similarityFunction, hnsw, acceptOrds);
|
||||
HnswGraphSearcher.search(
|
||||
query, 100, vectors, similarityFunction, hnsw, acceptOrds, Integer.MAX_VALUE);
|
||||
while (actual.size() > topK) {
|
||||
actual.pop();
|
||||
}
|
||||
|
|
|
@ -161,7 +161,8 @@ public class TermVectorLeafReader extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -1351,7 +1351,8 @@ public class MemoryIndex {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -114,10 +114,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
FieldInfo fi = fis.fieldInfo(field);
|
||||
assert fi != null && fi.getVectorDimension() > 0;
|
||||
TopDocs hits = delegate.search(field, target, k, acceptDocs);
|
||||
TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
|
||||
assert hits != null;
|
||||
assert hits.scoreDocs.length <= k;
|
||||
return hits;
|
||||
|
|
|
@ -563,7 +563,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
|
||||
// assert that knn search doesn't fail on a field with all deleted docs
|
||||
TopDocs results =
|
||||
leafReader.searchNearestVectors("v", randomVector(3), 1, leafReader.getLiveDocs());
|
||||
leafReader.searchNearestVectors(
|
||||
"v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
|
||||
assertEquals(0, results.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
|
@ -887,7 +888,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
k = numLiveDocsWithVectors;
|
||||
}
|
||||
TopDocs results =
|
||||
ctx.reader().searchNearestVectors(fieldName, randomVector(dimension), k, liveDocs);
|
||||
ctx.reader()
|
||||
.searchNearestVectors(
|
||||
fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE);
|
||||
assertEquals(Math.min(k, size), results.scoreDocs.length);
|
||||
for (int i = 0; i < k - 1; i++) {
|
||||
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
|
||||
|
|
|
@ -223,9 +223,9 @@ class MergeReaderWrapper extends LeafReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
return in.searchNearestVectors(field, target, k, acceptDocs);
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -228,7 +228,8 @@ public class QueryUtils {
|
|||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue