Move byte vector queries into new KnnByteVectorQuery (#12004)

Benjamin Trent 2022-12-14 03:53:10 -05:00 committed by GitHub
parent 9eeab8c4a6
commit 72968d30ba
35 changed files with 2068 additions and 1355 deletions

lucene/CHANGES.txt

@ -135,6 +135,10 @@ API Changes
* GITHUB#11984: Improved TimeLimitBulkScorer to check the timeout at an exponential rate.
(Costin Leau)
* GITHUB#12004: Add new KnnByteVectorQuery for querying vector fields that are encoded as BYTE. Removes the ability to
use KnnVectorQuery against fields encoded as BYTE (Ben Trent)
New Features
---------------------
* GITHUB#11795: Add ByteWritesTrackingDirectoryWrapper to expose metrics for bytes merged, flushed, and overall
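
The GITHUB#12004 entry above changes query-time behavior for BYTE-encoded vector fields. A minimal, hedged usage sketch (not part of this commit; the directory setup, field name, and similarity choice are illustrative, and it assumes the KnnVectorField(String, BytesRef, VectorSimilarityFunction) constructor from this era of the API):

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

public class ByteVectorQueryExample {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
      byte[] vector = {1, 2, 3, 4};
      Document doc = new Document();
      // Indexing a BytesRef value marks the field with VectorEncoding.BYTE
      doc.add(new KnnVectorField("v", new BytesRef(vector), VectorSimilarityFunction.DOT_PRODUCT));
      writer.addDocument(doc);
      writer.commit();
      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        // After this change, KnnVectorQuery(field, float[], k) no longer matches
        // BYTE-encoded fields; the dedicated query must be used instead.
        TopDocs top = new IndexSearcher(reader).search(new KnnByteVectorQuery("v", vector, 10), 10);
      }
    }
  }
}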

Lucene90HnswVectorsReader.java

@ -276,6 +276,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

Lucene91HnswVectorsReader.java

@ -266,6 +266,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

Lucene92HnswVectorsReader.java

@ -40,6 +40,7 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -54,13 +55,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state);
boolean success = false;
try {
@ -260,18 +259,10 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {

Lucene94HnswVectorsReader.java

@ -41,6 +41,7 @@ import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -55,13 +56,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state);
boolean success = false;
try {
@ -249,7 +248,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0) {
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
@ -284,18 +283,44 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
NeighborQueue results =
HnswGraphSearcher.search(
target,
k,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),
visitedLimit);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
}
TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {
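
A note on the result-copying loop above: NeighborQueue is a score-ordered min-heap, so popping returns hits from worst to best, and writing each one at scoreDocs[scoreDocs.length - ++i] fills the array back to front, leaving it in the descending-score order TopDocs expects. A standalone sketch of the same idiom (plain JDK, not Lucene code; the statements are a fragment assumed to run inside a method):

import java.util.List;
import java.util.PriorityQueue;

PriorityQueue<Float> minHeap = new PriorityQueue<>(List.of(0.3f, 0.9f, 0.5f));
float[] byDescendingScore = new float[minHeap.size()];
int i = 0;
while (!minHeap.isEmpty()) {
  // poll() yields 0.3, then 0.5, then 0.9; back-to-front writes reverse that order
  byDescendingScore[byDescendingScore.length - ++i] = minHeap.poll();
}
// byDescendingScore is now {0.9f, 0.5f, 0.3f}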

SimpleTextKnnVectorsReader.java

@ -184,6 +184,12 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
@Override
public void checkIntegrity() throws IOException {
IndexInput clone = dataIn.clone();

BufferingKnnVectorsWriter.java

@ -90,6 +90,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
};
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@ -185,6 +191,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException();
}
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);

KnnVectorsFormat.java

@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedSPILoader;
/**
@ -103,6 +104,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
throw new UnsupportedOperationException();
}
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public void close() {}

KnnVectorsReader.java

@ -26,6 +26,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@ -80,6 +81,35 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
* true k closest neighbors. For large values of k (for example when k is close to the total
* number of documents), the search may also retrieve fewer than k documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
* contains the number of documents visited during the search. If the search stopped early because
* it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.
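
To make the visitedLimit contract above concrete, a hedged sketch (not from the patch; vectorsReader stands for any concrete KnnVectorsReader, and the field name, query bytes, and limit of 1000 are illustrative):

// Cap the graph traversal at 1000 visited nodes.
TopDocs docs =
    vectorsReader.search("field", new BytesRef(new byte[] {1, 2, 3, 4}), 10, null, 1000);
if (docs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
  // The search stopped early at the limit: totalHits.value reports the nodes
  // visited, and the hits may not be the true nearest neighbors.
}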

Lucene95HnswVectorsReader.java

@ -43,6 +43,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -255,6 +256,52 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
NeighborQueue results =
HnswGraphSearcher.search(
target,
k,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),
visitedLimit);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
}
TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());

PerFieldKnnVectorsFormat.java

@ -33,10 +33,9 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
/**
@ -259,12 +258,13 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.search(field, target, k, acceptDocs, visitedLimit);
}
return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
}
@Override

CheckIndex.java

@ -2598,18 +2598,25 @@ public final class CheckIndex implements Closeable {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
while (values.nextDoc() != NO_MORE_DOCS) {
float[] vectorValue = values.vectorValue();
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
TopDocs docs =
reader
.getVectorReader()
.search(fieldInfo.name, vectorValue, 10, null, Integer.MAX_VALUE);
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> reader
.getVectorReader()
.search(
fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
case BYTE -> reader
.getVectorReader()
.search(
fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
};
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
float[] vectorValue = values.vectorValue();
int valueLength = vectorValue.length;
if (valueLength != dimension) {
throw new CheckIndexException(

CodecReader.java

@ -27,6 +27,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@ -238,6 +239,19 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
}
@Override
public final TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// Field does not exist or does not index vectors
return null;
}
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
}
@Override
protected void doClose() throws IOException {}

DocValuesLeafReader.java

@ -20,6 +20,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
abstract class DocValuesLeafReader extends LeafReader {
@Override
@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException();
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public final void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();

FilterLeafReader.java

@ -357,6 +357,12 @@ public abstract class FilterLeafReader extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TermVectors termVectors() throws IOException {
ensureOpen();

LeafReader.java

@ -21,6 +21,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* {@code LeafReader} is an abstract class, providing an interface for accessing an index. Search of
@ -235,6 +236,34 @@ public abstract class LeafReader extends IndexReader {
public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
* true k closest neighbors. For large values of k (for example when k is close to the total
* number of documents), the search may also retrieve fewer than k documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
* TotalHits} contains the number of documents visited during the search. If the search stopped
* early because it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Get the {@link FieldInfos} describing all fields in this reader.
*

ParallelLeafReader.java

@ -29,6 +29,7 @@ import java.util.TreeMap;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@ -418,6 +419,17 @@ public class ParallelLeafReader extends LeafReader {
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectors(
String fieldName, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null
? null
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
}
@Override
public void checkIntegrity() throws IOException {
ensureOpen();

SlowCodecReaderWrapper.java

@ -30,6 +30,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* Wraps arbitrary readers for merging. Note that this can cause slow and memory-intensive merges.
@ -173,6 +174,12 @@ public final class SlowCodecReaderWrapper {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public void checkIntegrity() {
// We already checkIntegrity the entire reader up front

SortingCodecReader.java

@ -476,6 +476,12 @@ public final class SortingCodecReader extends FilterCodecReader {
throw new UnsupportedOperationException();
}
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public void close() throws IOException {
delegate.close();

AbstractKnnVectorQuery.java (new file)

@ -0,0 +1,410 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
*
* <ul>
* <li>If the filter cost is less than k, just execute an exact search
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
abstract class AbstractKnnVectorQuery extends Query {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
protected final String field;
protected final int k;
private final Query filter;
public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
this.filter = filter;
}
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
Weight filterWeight = null;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
perLeafResults[ctx.ord] = results;
}
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return createRewrittenQuery(reader, topK);
}
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
}
Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return NO_RESULTS;
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
int cost = acceptDocs.cardinality();
if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
}
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return bitSetIterator.getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}
protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorScorer vectorScorer = createVectorScorer(context, fi);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
for (int i = 0; i < len; i++) {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
}
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
int[] starts = new int[reader.leaves().size() + 1];
starts[starts.length - 1] = docs.length;
if (starts.length == 2) {
return starts;
}
int resultIndex = 0;
for (int i = 1; i < starts.length - 1; i++) {
int upper = reader.leaves().get(i).docBase;
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
if (resultIndex < 0) {
resultIndex = -1 - resultIndex;
}
starts[i] = resultIndex;
}
return starts;
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
}
@Override
public int hashCode() {
return Objects.hash(field, k, filter);
}
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
private final Object contextIdentity;
/**
* Constructor
*
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
* document in each segment. If a segment has no matching documents, it should be assigned
* the index of the next segment that does. There should be a final entry that is always
* docs.length.
* @param contextIdentity an object identifying the reader context that was used to build this
* query
*/
DocAndScoreQuery(
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc + context.docBase);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
return Explanation.match(scores[found] * boost, "within top " + k);
}
@Override
public Scorer scorer(LeafReaderContext context) {
return new Scorer(this) {
final int lower = segmentStarts[context.ord];
final int upper = segmentStarts[context.ord + 1];
int upTo = -1;
@Override
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
@Override
public int docID() {
return docIdNoShadow();
}
@Override
public int nextDoc() {
if (upTo == -1) {
upTo = lower;
} else {
++upTo;
}
return docIdNoShadow();
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return upper - lower;
}
};
}
@Override
public float getMaxScore(int docId) {
docId += context.docBase;
float maxScore = 0;
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore * boost;
}
@Override
public float score() {
return scores[upTo] * boost;
}
@Override
public int advanceShallow(int docid) {
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
return NO_MORE_DOCS;
}
return docs[docidIndex];
}
/**
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous
*
* @return the current docid
*/
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
}
if (upTo >= upper) {
return NO_MORE_DOCS;
}
return docs[upTo] - context.docBase;
}
@Override
public int docID() {
return docIdNoShadow();
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
@Override
public String toString(String field) {
return "DocAndScore[" + k + "]";
}
@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}
@Override
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
}
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
}
@Override
public int hashCode() {
return Objects.hash(
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
}
}
}
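
To ground the per-leaf strategy in a caller's view, a hedged sketch (searcher, the filter term, and the vector field name are illustrative): when the filter matches at most k documents in a leaf the search is exact; otherwise the HNSW search runs with visitedLimit set to the filter's cardinality and falls back to exact search if it reports an incomplete result.

Query filter = new TermQuery(new Term("color", "blue"));
Query knn = new KnnByteVectorQuery("v", new byte[] {1, 2, 3, 4}, 10, filter);
// IndexSearcher#search triggers rewrite(), which runs searchLeaf() per segment
// and caches the merged top-k in a DocAndScoreQuery.
TopDocs hits = searcher.search(knn, 10);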

KnnByteVectorQuery.java (new file)

@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* Uses {@link KnnVectorsReader#search(String, BytesRef, int, Bits, int)} to perform nearest
* neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
*
* <ul>
* <li>If the filter cost is less than k, just execute an exact search
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final BytesRef target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field.
*
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, byte[] target, int k) {
this(field, target, k, null);
}
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field.
*
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = new BytesRef(target);
}
@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
return results != null ? results : NO_RESULTS;
}
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
return VectorScorer.create(context, fi, target);
}
@Override
public String toString(String field) {
return getClass().getSimpleName()
+ ":"
+ this.field
+ "["
+ target.bytes[target.offset]
+ ",...]["
+ k
+ "]";
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (super.equals(o) == false) return false;
KnnByteVectorQuery that = (KnnByteVectorQuery) o;
return Objects.equals(target, that.target);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), target);
}
}

KnnVectorQuery.java

@ -16,23 +16,18 @@
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
* neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
@ -43,14 +38,11 @@ import org.apache.lucene.util.Bits;
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
public class KnnVectorQuery extends Query {
public class KnnVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final String field;
private final float[] target;
private final int k;
private final Query filter;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@ -76,173 +68,24 @@ public class KnnVectorQuery extends Query {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
this.field = field;
super(field, k, filter);
this.target = target;
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
this.filter = filter;
}
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
Weight filterWeight = null;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
perLeafResults[ctx.ord] = results;
}
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return createRewrittenQuery(reader, topK);
}
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
}
Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return NO_RESULTS;
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
int cost = acceptDocs.cardinality();
if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
}
}
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return bitSetIterator.getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
return results != null ? results : NO_RESULTS;
}
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
for (int i = 0; i < len; i++) {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
}
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
int[] starts = new int[reader.leaves().size() + 1];
starts[starts.length - 1] = docs.length;
if (starts.length == 2) {
return starts;
}
int resultIndex = 0;
for (int i = 1; i < starts.length - 1; i++) {
int upper = reader.leaves().get(i).docBase;
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
if (resultIndex < 0) {
resultIndex = -1 - resultIndex;
}
starts[i] = resultIndex;
}
return starts;
return VectorScorer.create(context, fi, target);
}
@Override
@ -251,195 +94,17 @@ public class KnnVectorQuery extends Query {
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}
@Override
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
}
return ((KnnVectorQuery) obj).k == k
&& ((KnnVectorQuery) obj).field.equals(field)
&& Arrays.equals(((KnnVectorQuery) obj).target, target)
&& Objects.equals(filter, ((KnnVectorQuery) obj).filter);
public boolean equals(Object o) {
if (this == o) return true;
if (super.equals(o) == false) return false;
KnnVectorQuery that = (KnnVectorQuery) o;
return Arrays.equals(target, that.target);
}
@Override
public int hashCode() {
return Objects.hash(classHash(), field, k, Arrays.hashCode(target), filter);
}
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
private final Object contextIdentity;
/**
* Constructor
*
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
* document in each segment. If a segment has no matching documents, it should be assigned
* the index of the next segment that does. There should be a final entry that is always
* docs.length.
* @param contextIdentity an object identifying the reader context that was used to build this
* query
*/
DocAndScoreQuery(
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc + context.docBase);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
return Explanation.match(scores[found] * boost, "within top " + k);
}
@Override
public Scorer scorer(LeafReaderContext context) {
return new Scorer(this) {
final int lower = segmentStarts[context.ord];
final int upper = segmentStarts[context.ord + 1];
int upTo = -1;
@Override
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
@Override
public int docID() {
return docIdNoShadow();
}
@Override
public int nextDoc() {
if (upTo == -1) {
upTo = lower;
} else {
++upTo;
}
return docIdNoShadow();
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return upper - lower;
}
};
}
@Override
public float getMaxScore(int docId) {
docId += context.docBase;
float maxScore = 0;
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore * boost;
}
@Override
public float score() {
return scores[upTo] * boost;
}
@Override
public int advanceShallow(int docid) {
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
return NO_MORE_DOCS;
}
return docs[docidIndex];
}
/**
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous
*
* @return the current docid
*/
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
}
if (upTo >= upper) {
return NO_MORE_DOCS;
}
return docs[upTo] - context.docBase;
}
@Override
public int docID() {
return docIdNoShadow();
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
@Override
public String toString(String field) {
return "DocAndScore[" + k + "]";
}
@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}
@Override
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
}
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
}
@Override
public int hashCode() {
return Objects.hash(
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
}
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(target);
return result;
}
}

VectorScorer.java

@ -22,7 +22,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
/**
* Computes the similarity score between a given query vector and different document vectors. This
@ -40,14 +39,18 @@ abstract class VectorScorer {
* @param fi the FieldInfo for the field containing document vectors
* @param query the query vector to compute the similarity for
*/
static VectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
throws IOException {
VectorValues values = context.reader().getVectorValues(fi.name);
final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new FloatVectorScorer(values, query, similarity);
}
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
throws IOException {
VectorValues values = context.reader().getVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return switch (fi.getVectorEncoding()) {
case BYTE -> new ByteVectorScorer(values, query, similarity);
case FLOAT32 -> new FloatVectorScorer(values, query, similarity);
};
return new ByteVectorScorer(values, query, similarity);
}
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
@ -74,9 +77,9 @@ abstract class VectorScorer {
private final BytesRef query;
protected ByteVectorScorer(
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
super(values, similarity);
this.query = VectorUtil.toBytesRef(query);
this.query = query;
}
@Override
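
With the factory split per encoding, the encoding check moves into each query's createVectorScorer override. A hedged sketch of the exact-scoring loop these scorers serve, mirroring AbstractKnnVectorQuery#exactSearch above (leafContext, fieldInfo, queryBytes, and candidates are assumed to be in scope):

VectorScorer scorer = VectorScorer.create(leafContext, fieldInfo, new BytesRef(queryBytes));
int doc;
while ((doc = candidates.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
  if (scorer.advanceExact(doc)) {
    float score = scorer.score(); // similarity between the query and this doc's byte vector
  }
}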

HnswGraphSearcher.java

@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
@ -96,17 +95,6 @@ public class HnswGraphSearcher<T> {
+ " differs from field dimension: "
+ vectors.dimension());
}
if (vectorEncoding == VectorEncoding.BYTE) {
return search(
toBytesRef(query),
topK,
vectors,
vectorEncoding,
similarityFunction,
graph,
acceptOrds,
visitedLimit);
}
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
@ -132,7 +120,21 @@ public class HnswGraphSearcher<T> {
return results;
}
private static NeighborQueue search(
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param topK the number of nodes to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return a priority queue holding the closest neighbors found
*/
public static NeighborQueue search(
BytesRef query,
int topK,
RandomAccessVectorValues vectors,
@ -142,6 +144,13 @@ public class HnswGraphSearcher<T> {
Bits acceptOrds,
int visitedLimit)
throws IOException {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
HnswGraphSearcher<BytesRef> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
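
Because the byte-vector entry point is now public, codec readers call it directly. A hedged sketch of such a call, with argument names following the Lucene95HnswVectorsReader#search body shown earlier (all of them assumed to be in scope):

NeighborQueue results =
    HnswGraphSearcher.search(
        new BytesRef(queryBytes), // length must equal vectors.dimension()
        k,
        vectorValues,             // a RandomAccessVectorValues over BYTE vectors
        VectorEncoding.BYTE,
        similarityFunction,
        graph,
        acceptOrds,               // null means all ordinals are allowed
        visitedLimit);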

TestSegmentToThreadMapping.java

@ -33,6 +33,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.Version;
@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null;
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
@Override
protected void doClose() {}

BaseKnnVectorQueryTestCase.java (new file)

@ -0,0 +1,887 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FilterDirectoryReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/** Test cases for KnnVectorQuery objects. */
abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract AbstractKnnVectorQuery getKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
String field, float[] query, int k, Query queryFilter);
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) {
return getKnnVectorQuery(field, query, k, null);
}
abstract float[] randomVector(int dim);
abstract VectorEncoding vectorEncoding();
abstract Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction);
abstract Field getKnnVectorField(String name, float[] vector);
public void testEquals() {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
Query filter1 = new TermQuery(new Term("id", "id1"));
AbstractKnnVectorQuery q2 = getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1);
assertNotEquals(q2, q1);
assertNotEquals(q1, q2);
assertEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1));
Query filter2 = new TermQuery(new Term("id", "id2"));
assertNotEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter2));
assertEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 10));
assertNotEquals(null, q1);
assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
assertNotEquals(q1, getKnnVectorQuery("f2", new float[] {0, 1}, 10));
assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {1, 1}, 10));
assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 2));
assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0}, 10));
}
/**
* Tests that an AbstractKnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no
* documents to match.
*/
public void testEmptyIndex() throws IOException {
try (Directory indexStore = getIndexStore("field");
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {1, 2}, 10);
assertMatches(searcher, kvq, 0);
Query q = searcher.rewrite(kvq);
assertTrue(q instanceof MatchNoDocsQuery);
}
}
/**
* Tests that an AbstractKnnVectorQuery whose topK &gt;= numDocs returns all the documents in score
* order
*/
public void testFindAll() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10);
assertMatches(searcher, kvq, 3);
ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs;
assertIdMatches(reader, "id2", scoreDocs[0]);
assertIdMatches(reader, "id0", scoreDocs[1]);
assertIdMatches(reader, "id1", scoreDocs[2]);
}
}
public void testSearchBoost() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query vectorQuery = getKnnVectorQuery("field", new float[] {0, 0}, 10);
ScoreDoc[] scoreDocs = searcher.search(vectorQuery, 3).scoreDocs;
Query boostQuery = new BoostQuery(vectorQuery, 3.0f);
ScoreDoc[] boostScoreDocs = searcher.search(boostQuery, 3).scoreDocs;
assertEquals(scoreDocs.length, boostScoreDocs.length);
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
ScoreDoc boostScoreDoc = boostScoreDocs[i];
assertEquals(scoreDoc.doc, boostScoreDoc.doc);
assertEquals(scoreDoc.score * 3.0f, boostScoreDoc.score, 0.001f);
}
}
}
/** Tests that an AbstractKnnVectorQuery applies the filter query */
public void testSimpleFilter() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("id", "id2"));
Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(1, topDocs.totalHits.value);
assertIdMatches(reader, "id2", topDocs.scoreDocs[0]);
}
}
public void testFilterWithNoVectorMatches() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(0, topDocs.totalHits.value);
}
}
/** Tests that querying with a vector whose dimension differs from the field's throws */
public void testDimensionMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
/** Tests that querying a missing field, or a field that is not a vector field, matches nothing */
public void testNonVectorField() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
assertMatches(searcher, getKnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
assertMatches(searcher, getKnnVectorQuery("id", new float[] {0}, 10), 0);
}
}
/** Test bad parameters */
public void testIllegalArguments() throws IOException {
expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] {1}, 0));
}
public void testDifferentReader() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query dasq = query.rewrite(newSearcher(reader));
IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
expectThrows(
IllegalStateException.class,
() -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
}
}
public void testAdvanceShallow() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query dasq = query.rewrite(searcher);
Scorer scorer =
dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
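// with k = 3 the rewritten query matches docs 1, 2 and 3, the nearest neighbors of (2, 3)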
// before advancing the iterator
assertEquals(1, scorer.advanceShallow(0));
assertEquals(1, scorer.advanceShallow(1));
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
// after advancing the iterator
scorer.iterator().advance(2);
assertEquals(2, scorer.advanceShallow(0));
assertEquals(2, scorer.advanceShallow(2));
assertEquals(3, scorer.advanceShallow(3));
assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
}
}
}
public void testScoreEuclidean() throws IOException {
float[][] vectors = new float[5][];
for (int j = 0; j < 5; j++) {
vectors[j] = new float[] {j, j};
}
try (Directory d = getStableIndexStore("field", vectors);
IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
// prior to advancing, score is undefined
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(1, it.nextDoc());
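// doc 1 has vector (1, 1): 1 / ((l2distance((2,3), (1,1)) = 5) + 1) = 1/6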
assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0);
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
}
public void testScoreCosine() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 1; j <= 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {j, j * j}, COSINE));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
// prior to advancing, score is undefined
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) /2.
*/
float maxAtZero =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) /2
*/
float expected =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
}
}
public void testExplain() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Explanation matched = searcher.explain(query, 2);
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue());
assertEquals(0, nomatch.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription());
}
}
}
public void testExplainMultipleSegments() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
w.commit();
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Explanation matched = searcher.explain(query, 2);
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue());
assertEquals(0, nomatch.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription());
}
}
}
/** Test that when vectors are abnormally distributed among segments, we still find the top K */
public void testSkewedIndex() throws IOException {
/* We have to choose the numbers carefully here so that some segment has more than the expected
* number of top K documents, but no more than K documents in total (otherwise we might occasionally
* randomly fail to find one).
*/
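// Here: 5 flushes of 5 docs each, so no single segment can hold more than 5 of the top 8.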
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
int r = 0;
for (int i = 0; i < 5; i++) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {r, r}));
doc.add(new StringField("id", "id" + r, Field.Store.YES));
w.addDocument(doc);
++r;
}
w.flush();
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
TopDocs results = searcher.search(getKnnVectorQuery("field", new float[] {0, 0}, 8), 10);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "id0", results.scoreDocs[0]);
assertIdMatches(reader, "id7", results.scoreDocs[7]);
// test some results in the middle of the sequence - also tests docid tiebreaking
results = searcher.search(getKnnVectorQuery("field", new float[] {10, 10}, 8), 10);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "id10", results.scoreDocs[0]);
assertIdMatches(reader, "id6", results.scoreDocs[7]);
}
}
}
/** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
public void testRandom() throws IOException {
int numDocs = atLeast(100);
int dimension = atLeast(5);
int numIters = atLeast(10);
boolean everyDocHasAVector = random().nextBoolean();
try (Directory d = newDirectory()) {
RandomIndexWriter w = new RandomIndexWriter(random(), d);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
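// unless every doc must have a vector, leave roughly 1 doc in 10 without one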
if (everyDocHasAVector || random().nextInt(10) != 2) {
doc.add(getKnnVectorField("field", randomVector(dimension)));
}
w.addDocument(doc);
}
w.close();
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int k = random().nextInt(80) + 1;
AbstractKnnVectorQuery query = getKnnVectorQuery("field", randomVector(dimension), k);
int n = random().nextInt(100) + 1;
TopDocs results = searcher.search(query, n);
int expected = Math.min(Math.min(n, k), reader.numDocs());
// we may get fewer results than requested if there are deletions, but this test doesn't
// test that
assert reader.hasDeletions() == false;
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value >= results.scoreDocs.length);
// verify the results are in descending score order
float last = Float.MAX_VALUE;
for (ScoreDoc scoreDoc : results.scoreDocs) {
assertTrue(scoreDoc.score <= last);
last = scoreDoc.score;
}
}
}
}
}
/** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
public void testRandomWithFilter() throws IOException {
int numDocs = 1000;
int dimension = atLeast(5);
int numIters = atLeast(10);
try (Directory d = newDirectory()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
// kNN format implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", randomVector(dimension)));
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (DirectoryReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int lower = random().nextInt(500);
// Test a filter with cost less than k and check we use exact search
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
TopDocs results =
searcher.search(
getKnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
getThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
numDocs));
// Test a restrictive filter and check we use exact search
Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
results =
searcher.search(
getKnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
numDocs));
// Test an unrestrictive filter and check we use approximate search
Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs);
results =
searcher.search(
getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
numDocs,
new Sort(new SortField("tag", SortField.Type.INT)));
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
for (ScoreDoc scoreDoc : results.scoreDocs) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
assertEquals(1, fieldDoc.fields.length);
int tag = (int) fieldDoc.fields[0];
assertTrue(lower <= tag && tag <= numDocs);
}
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
getThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
numDocs));
}
}
}
}
/** Tests filtering when all vectors have the same score. */
@AwaitsFix(bugUrl = "https://github.com/apache/lucene/issues/11787")
public void testFilterWithSameScore() throws IOException {
int numDocs = 100;
int dimension = atLeast(5);
try (Directory d = newDirectory()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
// kNN format implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
IndexWriter w = new IndexWriter(d, iwc);
float[] vector = randomVector(dimension);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", vector));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (DirectoryReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
int lower = random().nextInt(50);
int size = 5;
// Test a restrictive filter, which usually performs exact search
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
TopDocs results =
searcher.search(
getKnnVectorQuery("field", randomVector(dimension), size, filter1), size);
assertEquals(size, results.scoreDocs.length);
// Test an unrestrictive filter, which usually performs approximate search
Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
results =
searcher.search(
getKnnVectorQuery("field", randomVector(dimension), size, filter2), size);
assertEquals(size, results.scoreDocs.length);
}
}
}
public void testDeletes() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(new StringField("index", String.valueOf(i), Field.Store.YES));
if (frequently()) {
d.add(getKnnVectorField("vector", randomVector(dim)));
}
w.addDocument(d);
}
w.commit();
// Delete some documents at random, both those with and without vectors
Set<Term> toDelete = new HashSet<>();
for (int i = 0; i < 25; i++) {
int index = random().nextInt(numDocs);
toDelete.add(new Term("index", String.valueOf(index)));
}
w.deleteDocuments(toDelete.toArray(new Term[0]));
w.commit();
int hits = 50;
try (IndexReader reader = DirectoryReader.open(dir)) {
Set<String> allIds = new HashSet<>();
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), hits);
TopDocs topDocs = searcher.search(query, numDocs);
StoredFields storedFields = reader.storedFields();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
Document doc = storedFields.document(scoreDoc.doc, Set.of("index"));
String index = doc.get("index");
assertFalse(
"search returned a deleted document: " + index,
toDelete.contains(new Term("index", index)));
allIds.add(index);
}
assertEquals("search missed some documents", hits, allIds.size());
}
}
}
public void testAllDeletes() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
w.deleteDocuments(new MatchAllDocsQuery());
w.commit();
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
assertEquals(0, topDocs.scoreDocs.length);
}
}
}
/**
* Check that the query behaves reasonably when using a custom filter reader where there are no
* live docs.
*/
public void testNoLiveDocsReader() throws IOException {
IndexWriterConfig iwc = newIndexWriterConfig();
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, iwc)) {
final int numDocs = 10;
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(new StringField("index", String.valueOf(i), Field.Store.NO));
d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
try (DirectoryReader reader = DirectoryReader.open(dir)) {
DirectoryReader wrappedReader = new NoLiveDocsDirectoryReader(reader);
IndexSearcher searcher = new IndexSearcher(wrappedReader);
AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
assertEquals(0, topDocs.scoreDocs.length);
}
}
}
/**
* Test that AbstractKnnVectorQuery optimizes the case where the filter query is backed by {@link
* BitSetIterator}.
*/
public void testBitSetQuery() throws IOException {
IndexWriterConfig iwc = newIndexWriterConfig();
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, iwc)) {
final int numDocs = 100;
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
try (DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
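// The filter's scorer exposes a BitSetIterator, so the query is expected to use its BitSet
// directly; the throwing getBitSet() below proves that this optimized path is taken.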
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
getKnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
}
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
doc.add(getKnnVectorField(field, contents[i]));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {
Document doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
writer.addDocument(doc);
}
writer.close();
return indexStore;
}
/**
* Creates a new directory and adds documents with the given vectors as kNN vector fields,
* preserving the order of the added documents.
*/
private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
doc.add(getKnnVectorField(field, contents[i]));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {
Document doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
writer.addDocument(doc);
}
}
return indexStore;
}
private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
throws IOException {
ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
assertEquals(expectedMatches, result.length);
}
void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc)
throws IOException {
String actualId = reader.storedFields().document(scoreDoc.doc).get("id");
assertEquals(expectedId, actualId);
}
/**
* A {@link FilterDirectoryReader} whose leaves report no live docs. This lets us check that the
* query behaves reasonably when no documents in the index are live.
*/
private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {
private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {
super(
in,
new SubReaderWrapper() {
@Override
public LeafReader wrap(LeafReader reader) {
return new NoLiveDocsLeafReader(reader);
}
});
}
@Override
protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
return new NoLiveDocsDirectoryReader(in);
}
@Override
public CacheHelper getReaderCacheHelper() {
return in.getReaderCacheHelper();
}
}
private static class NoLiveDocsLeafReader extends FilterLeafReader {
private NoLiveDocsLeafReader(LeafReader in) {
super(in);
}
@Override
public int numDocs() {
return 0;
}
@Override
public Bits getLiveDocs() {
return new Bits.MatchNoBits(in.maxDoc());
}
@Override
public CacheHelper getReaderCacheHelper() {
return in.getReaderCacheHelper();
}
@Override
public CacheHelper getCoreCacheHelper() {
return in.getCoreCacheHelper();
}
}
static class ThrowingBitSetQuery extends Query {
private final FixedBitSet docs;
ThrowingBitSetQuery(FixedBitSet docs) {
this.docs = docs;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
BitSetIterator bitSetIterator =
new BitSetIterator(docs, docs.approximateCardinality()) {
@Override
public BitSet getBitSet() {
throw new UnsupportedOperationException("reusing BitSet is not supported");
}
};
return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}
@Override
public void visit(QueryVisitor visitor) {}
@Override
public String toString(String field) {
return "throwingBitSetQuery";
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
}
@Override
public int hashCode() {
return 31 * classHash() + docs.hashCode();
}
}
}

View File

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
}
@Override
float[] randomVector(int dim) {
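// draw random bytes, then widen them to floats so the shared test cases can round-trip
// them through floatToBytes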
BytesRef bytesRef = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[bytesRef.length];
int vi = 0;
for (int i = bytesRef.offset; i < bytesRef.offset + bytesRef.length; i++) {
v[vi++] = bytesRef.bytes[i];
}
return v;
}
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
}
@Override
Field getKnnVectorField(String name, float[] vector) {
return new KnnVectorField(
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
}
private static byte[] floatToBytes(float[] query) {
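// the narrowing cast below is lossless only for whole numbers in
// [Byte.MIN_VALUE, Byte.MAX_VALUE], which the assert enforces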
byte[] bytes = new byte[query.length];
for (int i = 0; i < query.length; i++) {
assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
: "float value cannot be converted to byte; provided: " + query[i];
bytes[i] = (byte) query[i];
}
return bytes;
}
public void testToString() {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
}
@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.BYTE;
}
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
}
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
throw new UnsupportedOperationException("exact search is not supported");
}
@Override
public String toString(String field) {
return null;
}
}
}

View File

@ -172,6 +172,17 @@ public class TestVectorUtil extends LuceneTestCase {
return v;
}
public static BytesRef randomVectorBytes(int dim) {
BytesRef v = TestUtil.randomBinaryTerm(random(), dim);
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {
if (v.bytes[i] == -128) {
v.bytes[i] = -127;
}
}
return v;
}
public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});

View File

@ -282,15 +282,26 @@ public class TestHnswGraph extends LuceneTestCase {
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// run some searches
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
switch (vectorEncoding) {
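// dispatch on the index encoding: BYTE targets take the BytesRef overload,
// FLOAT32 targets the float[] overload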
case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
};
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
@ -324,15 +335,26 @@ public class TestHnswGraph extends LuceneTestCase {
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
switch (vectorEncoding) {
case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
@ -363,15 +385,27 @@ public class TestHnswGraph extends LuceneTestCase {
// Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality();
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
numAccepted,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
switch (vectorEncoding) {
case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(),
numAccepted,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(),
numAccepted,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
int[] nodes = nn.nodes();
assertEquals(numAccepted, nodes.length);
for (int node : nodes) {
@ -383,6 +417,10 @@ public class TestHnswGraph extends LuceneTestCase {
return new float[] {1, 0};
}
private BytesRef getTargetByteVector() {
return new BytesRef(new byte[] {1, 0});
}
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
@ -432,15 +470,27 @@ public class TestHnswGraph extends LuceneTestCase {
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
topK,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);
switch (vectorEncoding) {
case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(),
topK,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);
case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(),
topK,
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);
};
assertTrue(nn.incomplete());
// The visited count shouldn't exceed the limit
assertTrue(nn.visitedCount() <= visitedLimit);
@ -664,15 +714,27 @@ public class TestHnswGraph extends LuceneTestCase {
query = randomVector(random(), dim);
}
actual =
HnswGraphSearcher.search(
query,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
switch (vectorEncoding) {
case BYTE -> HnswGraphSearcher.search(
bQuery,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
query,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
while (actual.size() > topK) {
actual.pop();
}

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@ -170,6 +171,12 @@ public class TermVectorLeafReader extends LeafReader {
return null;
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
@Override
public void checkIntegrity() throws IOException {}

View File

@ -1401,6 +1401,12 @@ public class MemoryIndex {
return null;
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
@Override
public void checkIntegrity() throws IOException {
// no-op

View File

@ -28,10 +28,12 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Wraps the default KnnVectorsFormat and provides additional assertions. */
public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
@ -124,7 +126,22 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
assert fi != null
&& fi.getVectorDimension() > 0
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null
&& fi.getVectorDimension() > 0
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
assert hits != null;
assert hits.scoreDocs.length <= k;

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* This is a hack to make index sorting fast, with a {@link LeafReader} that always returns merge
@ -227,6 +228,12 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public int numDocs() {
return in.numDocs();

View File

@ -54,6 +54,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.junit.Assert;
@ -234,6 +235,12 @@ public class QueryUtils {
return null;
}
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
@Override
public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY;