mirror of https://github.com/apache/lucene.git
Move byte vector queries into new KnnByteVectorQuery (#12004)
This commit is contained in:
parent 9eeab8c4a6
commit 72968d30ba
@@ -135,6 +135,10 @@ API Changes

 * GITHUB#11984: Improved TimeLimitingBulkScorer to check the timeout at an exponential rate.
   (Costin Leau)

+* GITHUB#12004: Add new KnnByteVectorQuery for querying vector fields that are encoded as BYTE. Removes the ability to
+  use KnnVectorQuery against fields encoded as BYTE (Ben Trent)

New Features
---------------------

* GITHUB#11795: Add ByteWritesTrackingDirectoryWrapper to expose metrics for bytes merged, flushed, and overall
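For context, a minimal usage sketch of the new query added by this commit. The index path and field name are hypothetical; the field must have been indexed with a BYTE vector encoding:

import java.nio.file.Paths;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.FSDirectory;

public class KnnByteVectorQueryExample {
  public static void main(String[] args) throws Exception {
    try (DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get("/tmp/idx")))) {
      IndexSearcher searcher = new IndexSearcher(reader);
      byte[] queryVector = {1, 2, 3, 4}; // must match the indexed field's dimension
      // find the 10 nearest neighbours in the hypothetical BYTE-encoded field "vectors"
      TopDocs topDocs = searcher.search(new KnnByteVectorQuery("vectors", queryVector, 10), 10);
      System.out.println(topDocs.totalHits);
    }
  }
}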
@@ -276,6 +276,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
  }

+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
  private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
    IndexInput bytesSlice =
        vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
@@ -266,6 +266,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
  }

+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
  private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
    IndexInput bytesSlice =
        vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
@@ -40,6 +40,7 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -54,13 +55,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
 */
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {

-  private final FieldInfos fieldInfos;
  private final Map<String, FieldEntry> fields = new HashMap<>();
  private final IndexInput vectorData;
  private final IndexInput vectorIndex;

  Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
-    this.fieldInfos = state.fieldInfos;
    int versionMeta = readMetadata(state);
    boolean success = false;
    try {
@@ -260,18 +259,10 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
  }

-  /** Get knn graph values; used for testing */
-  public HnswGraph getGraph(String field) throws IOException {
-    FieldInfo info = fieldInfos.fieldInfo(field);
-    if (info == null) {
-      throw new IllegalArgumentException("No such field '" + field + "'");
-    }
-    FieldEntry entry = fields.get(field);
-    if (entry != null && entry.vectorIndexLength > 0) {
-      return getGraph(entry);
-    } else {
-      return HnswGraph.EMPTY;
-    }
+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    throw new UnsupportedOperationException();
  }

  private HnswGraph getGraph(FieldEntry entry) throws IOException {
@@ -41,6 +41,7 @@ import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -55,13 +56,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
 */
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {

-  private final FieldInfos fieldInfos;
  private final Map<String, FieldEntry> fields = new HashMap<>();
  private final IndexInput vectorData;
  private final IndexInput vectorIndex;

  Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
-    this.fieldInfos = state.fieldInfos;
    int versionMeta = readMetadata(state);
    boolean success = false;
    try {
@@ -249,7 +248,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
      throws IOException {
    FieldEntry fieldEntry = fields.get(field);

-    if (fieldEntry.size() == 0) {
+    if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
    }
@@ -284,18 +283,44 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
  }

-  /** Get knn graph values; used for testing */
-  public HnswGraph getGraph(String field) throws IOException {
-    FieldInfo info = fieldInfos.fieldInfo(field);
-    if (info == null) {
-      throw new IllegalArgumentException("No such field '" + field + "'");
-    }
-    FieldEntry entry = fields.get(field);
-    if (entry != null && entry.vectorIndexLength > 0) {
-      return getGraph(entry);
-    } else {
-      return HnswGraph.EMPTY;
-    }
+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    FieldEntry fieldEntry = fields.get(field);
+
+    if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    }
+
+    // bound k by total number of vectors to prevent oversizing data structures
+    k = Math.min(k, fieldEntry.size());
+    OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+
+    NeighborQueue results =
+        HnswGraphSearcher.search(
+            target,
+            k,
+            vectorValues,
+            fieldEntry.vectorEncoding,
+            fieldEntry.similarityFunction,
+            getGraph(fieldEntry),
+            vectorValues.getAcceptOrds(acceptDocs),
+            visitedLimit);
+
+    int i = 0;
+    ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
+    while (results.size() > 0) {
+      int node = results.topNode();
+      float score = results.topScore();
+      results.pop();
+      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
+    }
+
+    TotalHits.Relation relation =
+        results.incomplete()
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
  }

  private HnswGraph getGraph(FieldEntry entry) throws IOException {
@@ -184,6 +184,12 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
    return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
  }

+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+  }
+
  @Override
  public void checkIntegrity() throws IOException {
    IndexInput clone = dataIn.clone();
@@ -90,6 +90,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
              String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
            throw new UnsupportedOperationException();
          }

+          @Override
+          public TopDocs search(
+              String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+            throw new UnsupportedOperationException();
+          }
        };

    writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@@ -185,6 +191,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
            throw new UnsupportedOperationException();
          }

+          @Override
+          public TopDocs search(
+              String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+            throw new UnsupportedOperationException();
+          }
+
          @Override
          public VectorValues getVectorValues(String field) throws IOException {
            return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
@@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedSPILoader;

/**
@@ -103,6 +104,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
            throw new UnsupportedOperationException();
          }

+          @Override
+          public TopDocs search(
+              String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+            throw new UnsupportedOperationException();
+          }
+
          @Override
          public void close() {}
@@ -26,6 +26,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@@ -80,6 +81,35 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
  public abstract TopDocs search(
      String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

+  /**
+   * Return the k nearest neighbor documents as determined by comparison of their vector values for
+   * this field, to the given vector, by the field's similarity function. The score of each document
+   * is derived from the vector similarity in a way that ensures scores are positive and that a
+   * larger score corresponds to a higher ranking.
+   *
+   * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
+   * true k closest neighbors. For large values of k (for example when k is close to the total
+   * number of documents), the search may also retrieve fewer than k documents.
+   *
+   * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
+   * order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
+   * contains the number of documents visited during the search. If the search stopped early because
+   * it hit {@code visitedLimit}, it is indicated through the relation {@code
+   * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
+   *
+   * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
+   * FieldInfo}. The return value is never {@code null}.
+   *
+   * @param field the vector field to search
+   * @param target the vector-valued query
+   * @param k the number of docs to return
+   * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+   *     if they are all allowed to match.
+   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+   * @return the k nearest neighbor documents, along with their (similarity-specific) scores.
+   */
+  public abstract TopDocs search(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+
  /**
   * Returns an instance optimized for merging. This instance may only be consumed in the thread
   * that called {@link #getMergeInstance()}.
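The visitedLimit contract documented above is observable through the returned TotalHits relation. A sketch against the higher-level reader API this commit wires up (the LeafReader instance and the "vectors" field are assumptions for illustration):

// Allow the graph search to visit at most 1000 nodes before giving up.
TopDocs topDocs =
    leafReader.searchNearestVectors(
        "vectors", new BytesRef(new byte[] {1, 2, 3, 4}), 10, null, 1000);
if (topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
  // the search stopped early after hitting the visited limit; results may be partial
} else {
  // TotalHits.Relation.EQUAL_TO: the search ran to completion
}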
@@ -43,6 +43,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -255,6 +256,52 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
    if (fieldEntry.size() == 0) {
      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
    }
+    if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
+      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    }
+
+    // bound k by total number of vectors to prevent oversizing data structures
+    k = Math.min(k, fieldEntry.size());
+    OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+
+    NeighborQueue results =
+        HnswGraphSearcher.search(
+            target,
+            k,
+            vectorValues,
+            fieldEntry.vectorEncoding,
+            fieldEntry.similarityFunction,
+            getGraph(fieldEntry),
+            vectorValues.getAcceptOrds(acceptDocs),
+            visitedLimit);
+
+    int i = 0;
+    ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
+    while (results.size() > 0) {
+      int node = results.topNode();
+      float score = results.topScore();
+      results.pop();
+      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
+    }
+
+    TotalHits.Relation relation =
+        results.incomplete()
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
+  }
+
+  @Override
+  public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    FieldEntry fieldEntry = fields.get(field);
+
+    if (fieldEntry.size() == 0) {
+      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    }
+    if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    }
+
    // bound k by total number of vectors to prevent oversizing data structures
    k = Math.min(k, fieldEntry.size());
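One step in the loop above is easy to misread: NeighborQueue is a min-heap on score, so topNode()/pop() yield results worst-first, and the array is filled from the back (scoreDocs.length - ++i) so it comes out in descending-score order. The same pattern with a plain JDK min-heap, for illustration only:

import java.util.PriorityQueue;

PriorityQueue<Float> minHeap = new PriorityQueue<>(); // smallest score on top
minHeap.add(0.9f);
minHeap.add(0.7f);
minHeap.add(0.8f);

float[] byDescendingScore = new float[minHeap.size()];
int i = 0;
while (minHeap.isEmpty() == false) {
  // poll() returns the smallest remaining score, so write from the end backwards
  byDescendingScore[byDescendingScore.length - ++i] = minHeap.poll();
}
// byDescendingScore is now {0.9f, 0.8f, 0.7f}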
@@ -33,10 +33,9 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;

/**
@@ -259,12 +258,13 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
    @Override
    public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
        throws IOException {
-      KnnVectorsReader knnVectorsReader = fields.get(field);
-      if (knnVectorsReader == null) {
-        return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
-      } else {
-        return knnVectorsReader.search(field, target, k, acceptDocs, visitedLimit);
-      }
+      return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
    }

+    @Override
+    public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+        throws IOException {
+      return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
+    }
+
    @Override
@@ -2598,18 +2598,25 @@ public final class CheckIndex implements Closeable {
    int docCount = 0;
    int everyNdoc = Math.max(values.size() / 64, 1);
    while (values.nextDoc() != NO_MORE_DOCS) {
-      float[] vectorValue = values.vectorValue();
      // search the first maxNumSearches vectors to exercise the graph
      if (values.docID() % everyNdoc == 0) {
        TopDocs docs =
-            reader
-                .getVectorReader()
-                .search(fieldInfo.name, vectorValue, 10, null, Integer.MAX_VALUE);
+            switch (fieldInfo.getVectorEncoding()) {
+              case FLOAT32 -> reader
+                  .getVectorReader()
+                  .search(
+                      fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
+              case BYTE -> reader
+                  .getVectorReader()
+                  .search(
+                      fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
+            };
        if (docs.scoreDocs.length == 0) {
          throw new CheckIndexException(
              "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
        }
      }
+      float[] vectorValue = values.vectorValue();
      int valueLength = vectorValue.length;
      if (valueLength != dimension) {
        throw new CheckIndexException(
@@ -27,6 +27,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@@ -238,6 +239,19 @@ public abstract class CodecReader extends LeafReader {
    return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
  }

+  @Override
+  public final TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+    ensureOpen();
+    FieldInfo fi = getFieldInfos().fieldInfo(field);
+    if (fi == null || fi.getVectorDimension() == 0) {
+      // Field does not exist or does not index vectors
+      return null;
+    }
+
+    return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
+  }
+
  @Override
  protected void doClose() throws IOException {}
@@ -20,6 +20,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

abstract class DocValuesLeafReader extends LeafReader {
  @Override
@@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
    throw new UnsupportedOperationException();
  }

+  @Override
+  public TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
  @Override
  public final void checkIntegrity() throws IOException {
    throw new UnsupportedOperationException();
@@ -357,6 +357,12 @@ public abstract class FilterLeafReader extends LeafReader {
    return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
  }

+  @Override
+  public TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+    return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+  }
+
  @Override
  public TermVectors termVectors() throws IOException {
    ensureOpen();
@@ -21,6 +21,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

/**
 * {@code LeafReader} is an abstract class, providing an interface for accessing an index. Search of
@@ -235,6 +236,34 @@ public abstract class LeafReader extends IndexReader {
  public abstract TopDocs searchNearestVectors(
      String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

+  /**
+   * Return the k nearest neighbor documents as determined by comparison of their vector values for
+   * this field, to the given vector, by the field's similarity function. The score of each document
+   * is derived from the vector similarity in a way that ensures scores are positive and that a
+   * larger score corresponds to a higher ranking.
+   *
+   * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
+   * true k closest neighbors. For large values of k (for example when k is close to the total
+   * number of documents), the search may also retrieve fewer than k documents.
+   *
+   * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
+   * sorted in order of their similarity to the query vector (decreasing scores). The {@link
+   * TotalHits} contains the number of documents visited during the search. If the search stopped
+   * early because it hit {@code visitedLimit}, it is indicated through the relation {@code
+   * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
+   *
+   * @param field the vector field to search
+   * @param target the vector-valued query
+   * @param k the number of docs to return
+   * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+   *     if they are all allowed to match.
+   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+   * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
+   * @lucene.experimental
+   */
+  public abstract TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+
  /**
   * Get the {@link FieldInfos} describing all fields in this reader.
   *
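The new per-leaf method composes the same way the float variant does. A sketch of iterating segment leaves and remapping hits to global doc ids, mirroring what AbstractKnnVectorQuery.rewrite does internally (indexReader, target and the field name are assumptions for illustration):

for (LeafReaderContext ctx : indexReader.leaves()) {
  TopDocs perLeaf =
      ctx.reader()
          .searchNearestVectors("vectors", target, 10, ctx.reader().getLiveDocs(), Integer.MAX_VALUE);
  for (ScoreDoc scoreDoc : perLeaf.scoreDocs) {
    scoreDoc.doc += ctx.docBase; // convert the leaf-relative doc id to a global one
  }
  // per-leaf results can then be combined, e.g. with TopDocs.merge(k, ...)
}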
@@ -29,6 +29,7 @@ import java.util.TreeMap;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;

/**
@@ -418,6 +419,17 @@ public class ParallelLeafReader extends LeafReader {
        : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
  }

+  @Override
+  public TopDocs searchNearestVectors(
+      String fieldName, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+      throws IOException {
+    ensureOpen();
+    LeafReader reader = fieldToReader.get(fieldName);
+    return reader == null
+        ? null
+        : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
+  }
+
  @Override
  public void checkIntegrity() throws IOException {
    ensureOpen();
@@ -30,6 +30,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

/**
 * Wraps arbitrary readers for merging. Note that this can cause slow and memory-intensive merges.
@@ -173,6 +174,12 @@ public final class SlowCodecReaderWrapper {
        return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
      }

+      @Override
+      public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+          throws IOException {
+        return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+      }
+
      @Override
      public void checkIntegrity() {
        // We already checkIntegrity the entire reader up front
@@ -476,6 +476,12 @@ public final class SortingCodecReader extends FilterCodecReader {
        throw new UnsupportedOperationException();
      }

+      @Override
+      public TopDocs search(
+          String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+        throw new UnsupportedOperationException();
+      }
+
      @Override
      public void close() throws IOException {
        delegate.close();
@@ -0,0 +1,410 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;

/**
 * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
 *
 * <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
 * executes the filter for each leaf, then chooses a strategy dynamically:
 *
 * <ul>
 *   <li>If the filter cost is less than k, just execute an exact search
 *   <li>Otherwise run a kNN search subject to the filter
 *   <li>If the kNN search visits too many vectors without completing, stop and run an exact search
 * </ul>
 */
abstract class AbstractKnnVectorQuery extends Query {

  private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

  protected final String field;
  protected final int k;
  private final Query filter;

  public AbstractKnnVectorQuery(String field, int k, Query filter) {
    this.field = field;
    this.k = k;
    if (k < 1) {
      throw new IllegalArgumentException("k must be at least 1, got: " + k);
    }
    this.filter = filter;
  }

  @Override
  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
    IndexReader reader = indexSearcher.getIndexReader();
    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

    Weight filterWeight = null;
    if (filter != null) {
      BooleanQuery booleanQuery =
          new BooleanQuery.Builder()
              .add(filter, BooleanClause.Occur.FILTER)
              .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
              .build();
      Query rewritten = indexSearcher.rewrite(booleanQuery);
      filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
    }

    for (LeafReaderContext ctx : reader.leaves()) {
      TopDocs results = searchLeaf(ctx, filterWeight);
      if (ctx.docBase > 0) {
        for (ScoreDoc scoreDoc : results.scoreDocs) {
          scoreDoc.doc += ctx.docBase;
        }
      }
      perLeafResults[ctx.ord] = results;
    }
    // Merge sort the results
    TopDocs topK = TopDocs.merge(k, perLeafResults);
    if (topK.scoreDocs.length == 0) {
      return new MatchNoDocsQuery();
    }
    return createRewrittenQuery(reader, topK);
  }

  private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
    Bits liveDocs = ctx.reader().getLiveDocs();
    int maxDoc = ctx.reader().maxDoc();

    if (filterWeight == null) {
      return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
    }

    Scorer scorer = filterWeight.scorer(ctx);
    if (scorer == null) {
      return NO_RESULTS;
    }

    BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
    int cost = acceptDocs.cardinality();

    if (cost <= k) {
      // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
      // must always visit at least k documents
      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
    }

    // Perform the approximate kNN search
    TopDocs results = approximateSearch(ctx, acceptDocs, cost);
    if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
      return results;
    } else {
      // We stopped the kNN search because it visited too many nodes, so fall back to exact search
      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
    }
  }

  private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
      throws IOException {
    if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
      // If we already have a BitSet and no deletions, reuse the BitSet
      return bitSetIterator.getBitSet();
    } else {
      // Create a new BitSet from matching and live docs
      FilteredDocIdSetIterator filterIterator =
          new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
              return liveDocs == null || liveDocs.get(doc);
            }
          };
      return BitSet.of(filterIterator, maxDoc);
    }
  }

  protected abstract TopDocs approximateSearch(
      LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;

  abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
      throws IOException;

  // We allow this to be overridden so that tests can check what search strategy is used
  protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
      throws IOException {
    FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
    if (fi == null || fi.getVectorDimension() == 0) {
      // The field does not exist or does not index vectors
      return NO_RESULTS;
    }

    VectorScorer vectorScorer = createVectorScorer(context, fi);
    HitQueue queue = new HitQueue(k, true);
    ScoreDoc topDoc = queue.top();
    int doc;
    while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
      boolean advanced = vectorScorer.advanceExact(doc);
      assert advanced;

      float score = vectorScorer.score();
      if (score > topDoc.score) {
        topDoc.score = score;
        topDoc.doc = doc;
        topDoc = queue.updateTop();
      }
    }

    // Remove any remaining sentinel values
    while (queue.size() > 0 && queue.top().score < 0) {
      queue.pop();
    }

    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
      topScoreDocs[i] = queue.pop();
    }

    TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
    return new TopDocs(totalHits, topScoreDocs);
  }

  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
    int len = topK.scoreDocs.length;
    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
    int[] docs = new int[len];
    float[] scores = new float[len];
    for (int i = 0; i < len; i++) {
      docs[i] = topK.scoreDocs[i].doc;
      scores[i] = topK.scoreDocs[i].score;
    }
    int[] segmentStarts = findSegmentStarts(reader, docs);
    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
  }

  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
    int[] starts = new int[reader.leaves().size() + 1];
    starts[starts.length - 1] = docs.length;
    if (starts.length == 2) {
      return starts;
    }
    int resultIndex = 0;
    for (int i = 1; i < starts.length - 1; i++) {
      int upper = reader.leaves().get(i).docBase;
      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
      if (resultIndex < 0) {
        resultIndex = -1 - resultIndex;
      }
      starts[i] = resultIndex;
    }
    return starts;
  }

  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(field)) {
      visitor.visitLeaf(this);
    }
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
    return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
  }

  @Override
  public int hashCode() {
    return Objects.hash(field, k, filter);
  }

  /** Caches the results of a KnnVector search: a list of docs and their scores */
  static class DocAndScoreQuery extends Query {

    private final int k;
    private final int[] docs;
    private final float[] scores;
    private final int[] segmentStarts;
    private final Object contextIdentity;

    /**
     * Constructor
     *
     * @param k the number of documents requested
     * @param docs the global docids of documents that match, in ascending order
     * @param scores the scores of the matching documents
     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
     *     document in each segment. If a segment has no matching documents, it should be assigned
     *     the index of the next segment that does. There should be a final entry that is always
     *     docs.length-1.
     * @param contextIdentity an object identifying the reader context that was used to build this
     *     query
     */
    DocAndScoreQuery(
        int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
      this.k = k;
      this.docs = docs;
      this.scores = scores;
      this.segmentStarts = segmentStarts;
      this.contextIdentity = contextIdentity;
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
        throws IOException {
      if (searcher.getIndexReader().getContext().id() != contextIdentity) {
        throw new IllegalStateException("This DocAndScore query was created by a different reader");
      }
      return new Weight(this) {
        @Override
        public Explanation explain(LeafReaderContext context, int doc) {
          int found = Arrays.binarySearch(docs, doc + context.docBase);
          if (found < 0) {
            return Explanation.noMatch("not in top " + k);
          }
          return Explanation.match(scores[found] * boost, "within top " + k);
        }

        @Override
        public Scorer scorer(LeafReaderContext context) {

          return new Scorer(this) {
            final int lower = segmentStarts[context.ord];
            final int upper = segmentStarts[context.ord + 1];
            int upTo = -1;

            @Override
            public DocIdSetIterator iterator() {
              return new DocIdSetIterator() {
                @Override
                public int docID() {
                  return docIdNoShadow();
                }

                @Override
                public int nextDoc() {
                  if (upTo == -1) {
                    upTo = lower;
                  } else {
                    ++upTo;
                  }
                  return docIdNoShadow();
                }

                @Override
                public int advance(int target) throws IOException {
                  return slowAdvance(target);
                }

                @Override
                public long cost() {
                  return upper - lower;
                }
              };
            }

            @Override
            public float getMaxScore(int docId) {
              docId += context.docBase;
              float maxScore = 0;
              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
                maxScore = Math.max(maxScore, scores[idx]);
              }
              return maxScore * boost;
            }

            @Override
            public float score() {
              return scores[upTo] * boost;
            }

            @Override
            public int advanceShallow(int docid) {
              int start = Math.max(upTo, lower);
              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
              if (docidIndex < 0) {
                docidIndex = -1 - docidIndex;
              }
              if (docidIndex >= upper) {
                return NO_MORE_DOCS;
              }
              return docs[docidIndex];
            }

            /**
             * move the implementation of docID() into a differently-named method so we can call it
             * from DocIDSetIterator.docID() even though this class is anonymous
             *
             * @return the current docid
             */
            private int docIdNoShadow() {
              if (upTo == -1) {
                return -1;
              }
              if (upTo >= upper) {
                return NO_MORE_DOCS;
              }
              return docs[upTo] - context.docBase;
            }

            @Override
            public int docID() {
              return docIdNoShadow();
            }
          };
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
          return true;
        }
      };
    }

    @Override
    public String toString(String field) {
      return "DocAndScore[" + k + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
      visitor.visitLeaf(this);
    }

    @Override
    public boolean equals(Object obj) {
      if (sameClassAs(obj) == false) {
        return false;
      }
      return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
          && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
    }

    @Override
    public int hashCode() {
      return Objects.hash(
          classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
    }
  }
}
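The filter strategy documented in the class javadoc above is transparent to callers. A hedged sketch of a filtered float search (field names are hypothetical; searcher is an existing IndexSearcher):

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;

// Only documents matching the filter are candidates. If few enough documents match
// (cost <= k), the query short-circuits to the exact search path internally.
Query filter = new TermQuery(new Term("color", "blue"));
Query knn = new KnnVectorQuery("vectors", new float[] {0.1f, 0.2f, 0.3f, 0.4f}, 10, filter);
TopDocs hits = searcher.search(knn, 10);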
@@ -0,0 +1,116 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;

/**
 * Uses {@link KnnVectorsReader#search(String, BytesRef, int, Bits, int)} to perform nearest
 * neighbour search.
 *
 * <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
 * executes the filter for each leaf, then chooses a strategy dynamically:
 *
 * <ul>
 *   <li>If the filter cost is less than k, just execute an exact search
 *   <li>Otherwise run a kNN search subject to the filter
 *   <li>If the kNN search visits too many vectors without completing, stop and run an exact search
 * </ul>
 */
public class KnnByteVectorQuery extends AbstractKnnVectorQuery {

  private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

  private final BytesRef target;

  /**
   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
   * given field. <code>target</code> vector.
   *
   * @param field a field that has been indexed as a {@link KnnVectorField}.
   * @param target the target of the search
   * @param k the number of documents to find
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   */
  public KnnByteVectorQuery(String field, byte[] target, int k) {
    this(field, target, k, null);
  }

  /**
   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
   * given field. <code>target</code> vector.
   *
   * @param field a field that has been indexed as a {@link KnnVectorField}.
   * @param target the target of the search
   * @param k the number of documents to find
   * @param filter a filter applied before the vector search
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   */
  public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
    super(field, k, filter);
    this.target = new BytesRef(target);
  }

  @Override
  protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
      throws IOException {
    TopDocs results =
        context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
    return results != null ? results : NO_RESULTS;
  }

  @Override
  VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
    if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
      return null;
    }
    return VectorScorer.create(context, fi, target);
  }

  @Override
  public String toString(String field) {
    return getClass().getSimpleName()
        + ":"
        + this.field
        + "["
        + target.bytes[target.offset]
        + ",...]["
        + k
        + "]";
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (super.equals(o) == false) return false;
    KnnByteVectorQuery that = (KnnByteVectorQuery) o;
    return Objects.equals(target, that.target);
  }

  @Override
  public int hashCode() {
    return Objects.hash(super.hashCode(), target);
  }
}
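For completeness, a sketch of indexing documents that this query can search. The constructor used here is an assumption about the 9.x API, namely KnnVectorField(String, BytesRef, VectorSimilarityFunction), which indexes the field with a BYTE vector encoding:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

// writer is a previously opened IndexWriter; the BytesRef constructor is the assumed
// entry point for BYTE-encoded vectors in this release.
Document doc = new Document();
byte[] vector = {12, 1, 3, 8};
doc.add(new KnnVectorField("vectors", new BytesRef(vector), VectorSimilarityFunction.DOT_PRODUCT));
writer.addDocument(doc);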
@@ -16,23 +16,18 @@
 */
package org.apache.lucene.search;

-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
import java.io.IOException;
import java.util.Arrays;
-import java.util.Comparator;
-import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.util.BitSet;
-import org.apache.lucene.util.BitSetIterator;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;

/**
- * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
+ * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
+ * neighbour search.
 *
 * <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
 * executes the filter for each leaf, then chooses a strategy dynamically:
@@ -43,14 +38,11 @@ import org.apache.lucene.util.Bits;
 *   <li>If the kNN search visits too many vectors without completing, stop and run an exact search
 * </ul>
 */
-public class KnnVectorQuery extends Query {
+public class KnnVectorQuery extends AbstractKnnVectorQuery {

  private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

-  private final String field;
  private final float[] target;
-  private final int k;
-  private final Query filter;

  /**
   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@@ -76,173 +68,24 @@ public class KnnVectorQuery extends Query {
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   */
  public KnnVectorQuery(String field, float[] target, int k, Query filter) {
-    this.field = field;
+    super(field, k, filter);
    this.target = target;
-    this.k = k;
-    if (k < 1) {
-      throw new IllegalArgumentException("k must be at least 1, got: " + k);
-    }
-    this.filter = filter;
  }

-  @Override
-  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
-    IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
-
-    Weight filterWeight = null;
-    if (filter != null) {
-      BooleanQuery booleanQuery =
-          new BooleanQuery.Builder()
-              .add(filter, BooleanClause.Occur.FILTER)
-              .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
-              .build();
-      Query rewritten = indexSearcher.rewrite(booleanQuery);
-      filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
-    }
-
-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
-    // Merge sort the results
-    TopDocs topK = TopDocs.merge(k, perLeafResults);
-    if (topK.scoreDocs.length == 0) {
-      return new MatchNoDocsQuery();
-    }
-    return createRewrittenQuery(reader, topK);
-  }
-
-  private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
-    Bits liveDocs = ctx.reader().getLiveDocs();
-    int maxDoc = ctx.reader().maxDoc();
-
-    if (filterWeight == null) {
-      return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
-    }
-
-    Scorer scorer = filterWeight.scorer(ctx);
-    if (scorer == null) {
-      return NO_RESULTS;
-    }
-
-    BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
-    int cost = acceptDocs.cardinality();
-
-    if (cost <= k) {
-      // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
-      // must always visit at least k documents
-      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
-    }
-
-    // Perform the approximate kNN search
-    TopDocs results = approximateSearch(ctx, acceptDocs, cost);
-    if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
-      return results;
-    } else {
-      // We stopped the kNN search because it visited too many nodes, so fall back to exact search
-      return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
-    }
-  }
-
-  private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
-      throws IOException {
-    if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
-      // If we already have a BitSet and no deletions, reuse the BitSet
-      return bitSetIterator.getBitSet();
-    } else {
-      // Create a new BitSet from matching and live docs
-      FilteredDocIdSetIterator filterIterator =
-          new FilteredDocIdSetIterator(iterator) {
-            @Override
-            protected boolean match(int doc) {
-              return liveDocs == null || liveDocs.get(doc);
-            }
-          };
-      return BitSet.of(filterIterator, maxDoc);
-    }
-  }
-
-  private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
+  @Override
+  protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
      throws IOException {
    TopDocs results =
        context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
    return results != null ? results : NO_RESULTS;
  }

-  // We allow this to be overridden so that tests can check what search strategy is used
-  protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
-      throws IOException {
-    FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
-    if (fi == null || fi.getVectorDimension() == 0) {
-      // The field does not exist or does not index vectors
-      return NO_RESULTS;
+  @Override
+  VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
+    if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
+      return null;
    }
-
-    VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
-    HitQueue queue = new HitQueue(k, true);
-    ScoreDoc topDoc = queue.top();
-    int doc;
-    while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
-      boolean advanced = vectorScorer.advanceExact(doc);
-      assert advanced;
-
-      float score = vectorScorer.score();
-      if (score > topDoc.score) {
-        topDoc.score = score;
-        topDoc.doc = doc;
-        topDoc = queue.updateTop();
-      }
-    }
-
-    // Remove any remaining sentinel values
-    while (queue.size() > 0 && queue.top().score < 0) {
-      queue.pop();
-    }
-
-    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
-    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
-      topScoreDocs[i] = queue.pop();
-    }
-
-    TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
-    return new TopDocs(totalHits, topScoreDocs);
-  }
-
-  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
-    int len = topK.scoreDocs.length;
-    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
-    int[] docs = new int[len];
-    float[] scores = new float[len];
-    for (int i = 0; i < len; i++) {
-      docs[i] = topK.scoreDocs[i].doc;
-      scores[i] = topK.scoreDocs[i].score;
-    }
-    int[] segmentStarts = findSegmentStarts(reader, docs);
-    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
-  }
-
-  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
-    int[] starts = new int[reader.leaves().size() + 1];
-    starts[starts.length - 1] = docs.length;
-    if (starts.length == 2) {
-      return starts;
-    }
-    int resultIndex = 0;
-    for (int i = 1; i < starts.length - 1; i++) {
-      int upper = reader.leaves().get(i).docBase;
-      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
-      if (resultIndex < 0) {
-        resultIndex = -1 - resultIndex;
-      }
-      starts[i] = resultIndex;
-    }
-    return starts;
+    return VectorScorer.create(context, fi, target);
  }

  @Override
@@ -251,195 +94,17 @@ public class KnnVectorQuery extends Query {
   }
 
   @Override
-  public void visit(QueryVisitor visitor) {
-    if (visitor.acceptField(field)) {
-      visitor.visitLeaf(this);
-    }
-  }
-
-  @Override
-  public boolean equals(Object obj) {
-    if (sameClassAs(obj) == false) {
-      return false;
-    }
-    return ((KnnVectorQuery) obj).k == k
-        && ((KnnVectorQuery) obj).field.equals(field)
-        && Arrays.equals(((KnnVectorQuery) obj).target, target)
-        && Objects.equals(filter, ((KnnVectorQuery) obj).filter);
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (super.equals(o) == false) return false;
+    KnnVectorQuery that = (KnnVectorQuery) o;
+    return Arrays.equals(target, that.target);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(classHash(), field, k, Arrays.hashCode(target), filter);
-  }
-
-  /** Caches the results of a KnnVector search: a list of docs and their scores */
-  static class DocAndScoreQuery extends Query {
-
-    private final int k;
-    private final int[] docs;
-    private final float[] scores;
-    private final int[] segmentStarts;
-    private final Object contextIdentity;
-
-    /**
-     * Constructor
-     *
-     * @param k the number of documents requested
-     * @param docs the global docids of documents that match, in ascending order
-     * @param scores the scores of the matching documents
-     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
-     *     document in each segment. If a segment has no matching documents, it should be assigned
-     *     the index of the next segment that does. There should be a final entry that is always
-     *     docs.length-1.
-     * @param contextIdentity an object identifying the reader context that was used to build this
-     *     query
-     */
-    DocAndScoreQuery(
-        int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
-      this.k = k;
-      this.docs = docs;
-      this.scores = scores;
-      this.segmentStarts = segmentStarts;
-      this.contextIdentity = contextIdentity;
-    }
-
-    @Override
-    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
-        throws IOException {
-      if (searcher.getIndexReader().getContext().id() != contextIdentity) {
-        throw new IllegalStateException("This DocAndScore query was created by a different reader");
-      }
-      return new Weight(this) {
-        @Override
-        public Explanation explain(LeafReaderContext context, int doc) {
-          int found = Arrays.binarySearch(docs, doc + context.docBase);
-          if (found < 0) {
-            return Explanation.noMatch("not in top " + k);
-          }
-          return Explanation.match(scores[found] * boost, "within top " + k);
-        }
-
-        @Override
-        public Scorer scorer(LeafReaderContext context) {
-
-          return new Scorer(this) {
-            final int lower = segmentStarts[context.ord];
-            final int upper = segmentStarts[context.ord + 1];
-            int upTo = -1;
-
-            @Override
-            public DocIdSetIterator iterator() {
-              return new DocIdSetIterator() {
-                @Override
-                public int docID() {
-                  return docIdNoShadow();
-                }
-
-                @Override
-                public int nextDoc() {
-                  if (upTo == -1) {
-                    upTo = lower;
-                  } else {
-                    ++upTo;
-                  }
-                  return docIdNoShadow();
-                }
-
-                @Override
-                public int advance(int target) throws IOException {
-                  return slowAdvance(target);
-                }
-
-                @Override
-                public long cost() {
-                  return upper - lower;
-                }
-              };
-            }
-
-            @Override
-            public float getMaxScore(int docId) {
-              docId += context.docBase;
-              float maxScore = 0;
-              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
-                maxScore = Math.max(maxScore, scores[idx]);
-              }
-              return maxScore * boost;
-            }
-
-            @Override
-            public float score() {
-              return scores[upTo] * boost;
-            }
-
-            @Override
-            public int advanceShallow(int docid) {
-              int start = Math.max(upTo, lower);
-              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
-              if (docidIndex < 0) {
-                docidIndex = -1 - docidIndex;
-              }
-              if (docidIndex >= upper) {
-                return NO_MORE_DOCS;
-              }
-              return docs[docidIndex];
-            }
-
-            /**
-             * move the implementation of docID() into a differently-named method so we can call it
-             * from DocIDSetIterator.docID() even though this class is anonymous
-             *
-             * @return the current docid
-             */
-            private int docIdNoShadow() {
-              if (upTo == -1) {
-                return -1;
-              }
-              if (upTo >= upper) {
-                return NO_MORE_DOCS;
-              }
-              return docs[upTo] - context.docBase;
-            }
-
-            @Override
-            public int docID() {
-              return docIdNoShadow();
-            }
-          };
-        }
-
-        @Override
-        public boolean isCacheable(LeafReaderContext ctx) {
-          return true;
-        }
-      };
-    }
-
-    @Override
-    public String toString(String field) {
-      return "DocAndScore[" + k + "]";
-    }
-
-    @Override
-    public void visit(QueryVisitor visitor) {
-      visitor.visitLeaf(this);
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-      if (sameClassAs(obj) == false) {
-        return false;
-      }
-      return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
-          && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
-          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
-    }
-
-    @Override
-    public int hashCode() {
-      return Objects.hash(
-          classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
-    }
+    int result = super.hashCode();
+    result = 31 * result + Arrays.hashCode(target);
+    return result;
   }
 }
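To make the segmentStarts contract described in the javadoc above concrete, here is a small self-contained illustration; the segment layout and docid values are invented for the example, not taken from the patch. It mirrors how the anonymous Scorer derives its lower and upper bounds for one leaf.

public class SegmentStartsDemo {
  public static void main(String[] args) {
    // Assume three segments with docBase 0, 100 and 200, and five matching
    // documents identified by their global docids, in ascending order.
    int[] docs = {3, 7, 105, 201, 202};
    // segmentStarts[ord] is the index in docs of segment ord's first match;
    // the trailing entry bounds the last segment.
    int[] segmentStarts = {0, 2, 3, 5};

    int ord = 1; // the second segment, docBase = 100
    int lower = segmentStarts[ord];     // 2
    int upper = segmentStarts[ord + 1]; // 3, exclusive
    for (int i = lower; i < upper; i++) {
      // The scorer reports segment-local docids, hence the docBase shift.
      System.out.println("segment-local docid = " + (docs[i] - 100)); // prints 5
    }
  }
}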
@@ -22,7 +22,6 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
 
 /**
  * Computes the similarity score between a given query vector and different document vectors. This

@@ -40,14 +39,18 @@ abstract class VectorScorer {
   * @param fi the FieldInfo for the field containing document vectors
   * @param query the query vector to compute the similarity for
   */
-  static VectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
+  static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
       throws IOException {
     VectorValues values = context.reader().getVectorValues(fi.name);
     final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
+    return new FloatVectorScorer(values, query, similarity);
+  }
 
+  static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
+      throws IOException {
+    VectorValues values = context.reader().getVectorValues(fi.name);
+    VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
-    return switch (fi.getVectorEncoding()) {
-      case BYTE -> new ByteVectorScorer(values, query, similarity);
-      case FLOAT32 -> new FloatVectorScorer(values, query, similarity);
-    };
+    return new ByteVectorScorer(values, query, similarity);
   }
 
   VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {

@@ -74,9 +77,9 @@ abstract class VectorScorer {
   private final BytesRef query;
 
   protected ByteVectorScorer(
-      VectorValues values, float[] query, VectorSimilarityFunction similarity) {
+      VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
     super(values, similarity);
-    this.query = VectorUtil.toBytesRef(query);
+    this.query = query;
   }
 
   @Override
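Since the float[] overload above no longer dispatches on the field's encoding, a caller that still needs to handle both encodings has to pick the overload itself. Below is a minimal sketch of that dispatch; it assumes it lives in the same org.apache.lucene.search package (VectorScorer is package-private) and that both target representations are supplied by the caller.

package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BytesRef;

final class VectorScorerDispatch {
  /** Chooses the create(...) overload that matches the field's vector encoding. */
  static VectorScorer scorerFor(
      LeafReaderContext context, FieldInfo fi, float[] floatTarget, byte[] byteTarget)
      throws IOException {
    return switch (fi.getVectorEncoding()) {
      // BYTE-encoded fields go through the BytesRef overload ...
      case BYTE -> VectorScorer.create(context, fi, new BytesRef(byteTarget));
      // ... while FLOAT32-encoded fields keep using the float[] overload.
      case FLOAT32 -> VectorScorer.create(context, fi, floatTarget);
    };
  }
}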
@@ -18,7 +18,6 @@
 package org.apache.lucene.util.hnsw;
 
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-import static org.apache.lucene.util.VectorUtil.toBytesRef;
 
 import java.io.IOException;
 import org.apache.lucene.index.VectorEncoding;

@@ -96,17 +95,6 @@ public class HnswGraphSearcher<T> {
           + " differs from field dimension: "
           + vectors.dimension());
     }
-    if (vectorEncoding == VectorEncoding.BYTE) {
-      return search(
-          toBytesRef(query),
-          topK,
-          vectors,
-          vectorEncoding,
-          similarityFunction,
-          graph,
-          acceptOrds,
-          visitedLimit);
-    }
     HnswGraphSearcher<float[]> graphSearcher =
         new HnswGraphSearcher<>(
             vectorEncoding,

@@ -132,7 +120,21 @@ public class HnswGraphSearcher<T> {
     return results;
   }
 
-  private static NeighborQueue search(
+  /**
+   * Searches HNSW graph for the nearest neighbors of a query vector.
+   *
+   * @param query search query vector
+   * @param topK the number of nodes to be returned
+   * @param vectors the vector values
+   * @param similarityFunction the similarity function to compare vectors
+   * @param graph the graph values. May represent the entire graph, or a level in a hierarchical
+   *     graph.
+   * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
+   *     {@code null} if they are all allowed to match.
+   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+   * @return a priority queue holding the closest neighbors found
+   */
+  public static NeighborQueue search(
       BytesRef query,
       int topK,
       RandomAccessVectorValues vectors,

@@ -142,6 +144,13 @@ public class HnswGraphSearcher<T> {
       Bits acceptOrds,
       int visitedLimit)
       throws IOException {
+    if (query.length != vectors.dimension()) {
+      throw new IllegalArgumentException(
+          "vector query dimension: "
+              + query.length
+              + " differs from field dimension: "
+              + vectors.dimension());
+    }
     HnswGraphSearcher<BytesRef> graphSearcher =
         new HnswGraphSearcher<>(
             vectorEncoding,
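With the BytesRef entry point now public, byte vectors can be searched against a graph directly. A minimal sketch under the assumption that `vectors` (a RandomAccessVectorValues over BYTE-encoded data) and `graph` (for example an OnHeapHnswGraph built elsewhere) already exist; the target vector, topK, and visited limit are placeholder values.

import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

final class ByteGraphSearchSketch {
  static int[] topOrdinals(RandomAccessVectorValues vectors, HnswGraph graph) throws IOException {
    BytesRef target = new BytesRef(new byte[] {1, 0}); // length must match vectors.dimension()
    NeighborQueue nn =
        HnswGraphSearcher.search(
            target,
            10, // topK
            vectors,
            VectorEncoding.BYTE,
            VectorSimilarityFunction.EUCLIDEAN,
            graph,
            null, // acceptOrds: null means every ordinal may match
            Integer.MAX_VALUE); // visitedLimit
    return nn.nodes(); // graph ordinals of the nearest neighbors found
  }
}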
@@ -33,6 +33,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.NamedThreadFactory;
 import org.apache.lucene.util.Version;

@@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
       return null;
     }
 
+    @Override
+    public TopDocs searchNearestVectors(
+        String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+      return null;
+    }
+
     @Override
     protected void doClose() {}
@@ -0,0 +1,887 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FilterDirectoryReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/** Test cases for KnnVectorQuery objects. */
abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {

  abstract AbstractKnnVectorQuery getKnnVectorQuery(
      String field, float[] query, int k, Query queryFilter);

  abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
      String field, float[] query, int k, Query queryFilter);

  AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) {
    return getKnnVectorQuery(field, query, k, null);
  }

  abstract float[] randomVector(int dim);

  abstract VectorEncoding vectorEncoding();

  abstract Field getKnnVectorField(
      String name, float[] vector, VectorSimilarityFunction similarityFunction);

  abstract Field getKnnVectorField(String name, float[] vector);

  public void testEquals() {
    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
    Query filter1 = new TermQuery(new Term("id", "id1"));
    AbstractKnnVectorQuery q2 = getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1);

    assertNotEquals(q2, q1);
    assertNotEquals(q1, q2);
    assertEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1));

    Query filter2 = new TermQuery(new Term("id", "id2"));
    assertNotEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter2));

    assertEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 10));

    assertNotEquals(null, q1);

    assertNotEquals(q1, new TermQuery(new Term("f1", "x")));

    assertNotEquals(q1, getKnnVectorQuery("f2", new float[] {0, 1}, 10));
    assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {1, 1}, 10));
    assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 2));
    assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0}, 10));
  }

  /**
   * Tests if a AbstractKnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no
   * documents to match.
   */
  public void testEmptyIndex() throws IOException {
    try (Directory indexStore = getIndexStore("field");
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {1, 2}, 10);
      assertMatches(searcher, kvq, 0);
      Query q = searcher.rewrite(kvq);
      assertTrue(q instanceof MatchNoDocsQuery);
    }
  }

  /**
   * Tests that a AbstractKnnVectorQuery whose topK >= numDocs returns all the documents in score
   * order
   */
  public void testFindAll() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10);
      assertMatches(searcher, kvq, 3);
      ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs;
      assertIdMatches(reader, "id2", scoreDocs[0]);
      assertIdMatches(reader, "id0", scoreDocs[1]);
      assertIdMatches(reader, "id1", scoreDocs[2]);
    }
  }

  public void testSearchBoost() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);

      Query vectorQuery = getKnnVectorQuery("field", new float[] {0, 0}, 10);
      ScoreDoc[] scoreDocs = searcher.search(vectorQuery, 3).scoreDocs;

      Query boostQuery = new BoostQuery(vectorQuery, 3.0f);
      ScoreDoc[] boostScoreDocs = searcher.search(boostQuery, 3).scoreDocs;
      assertEquals(scoreDocs.length, boostScoreDocs.length);

      for (int i = 0; i < scoreDocs.length; i++) {
        ScoreDoc scoreDoc = scoreDocs[i];
        ScoreDoc boostScoreDoc = boostScoreDocs[i];

        assertEquals(scoreDoc.doc, boostScoreDoc.doc);
        assertEquals(scoreDoc.score * 3.0f, boostScoreDoc.score, 0.001f);
      }
    }
  }

  /** Tests that a AbstractKnnVectorQuery applies the filter query */
  public void testSimpleFilter() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      Query filter = new TermQuery(new Term("id", "id2"));
      Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
      TopDocs topDocs = searcher.search(kvq, 3);
      assertEquals(1, topDocs.totalHits.value);
      assertIdMatches(reader, "id2", topDocs.scoreDocs[0]);
    }
  }

  public void testFilterWithNoVectorMatches() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);

      Query filter = new TermQuery(new Term("other", "value"));
      Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
      TopDocs topDocs = searcher.search(kvq, 3);
      assertEquals(0, topDocs.totalHits.value);
    }
  }

  /** testDimensionMismatch */
  public void testDimensionMismatch() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
      IllegalArgumentException e =
          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
      assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
    }
  }

  /** testNonVectorField */
  public void testNonVectorField() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      assertMatches(searcher, getKnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
      assertMatches(searcher, getKnnVectorQuery("id", new float[] {0}, 10), 0);
    }
  }

  /** Test bad parameters */
  public void testIllegalArguments() throws IOException {
    expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] {1}, 0));
  }

  public void testDifferentReader() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
      Query dasq = query.rewrite(newSearcher(reader));
      IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
      expectThrows(
          IllegalStateException.class,
          () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
    }
  }

  public void testAdvanceShallow() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(getKnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
        Query dasq = query.rewrite(searcher);
        Scorer scorer =
            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
        // before advancing the iterator
        assertEquals(1, scorer.advanceShallow(0));
        assertEquals(1, scorer.advanceShallow(1));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));

        // after advancing the iterator
        scorer.iterator().advance(2);
        assertEquals(2, scorer.advanceShallow(0));
        assertEquals(2, scorer.advanceShallow(2));
        assertEquals(3, scorer.advanceShallow(3));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
      }
    }
  }

  public void testScoreEuclidean() throws IOException {
    float[][] vectors = new float[5][];
    for (int j = 0; j < 5; j++) {
      vectors[j] = new float[] {j, j};
    }
    try (Directory d = getStableIndexStore("field", vectors);
        IndexReader reader = DirectoryReader.open(d)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
      Query rewritten = query.rewrite(searcher);
      Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
      Scorer scorer = weight.scorer(reader.leaves().get(0));

      // prior to advancing, score is 0
      assertEquals(-1, scorer.docID());
      expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

      // test getMaxScore
      assertEquals(0, scorer.getMaxScore(-1), 0);
      assertEquals(0, scorer.getMaxScore(0), 0);
      // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
      assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
      assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);

      DocIdSetIterator it = scorer.iterator();
      assertEquals(3, it.cost());
      assertEquals(1, it.nextDoc());
      assertEquals(1 / 6f, scorer.score(), 0);
      assertEquals(3, it.advance(3));
      assertEquals(1 / 2f, scorer.score(), 0);
      assertEquals(NO_MORE_DOCS, it.advance(4));
      expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
    }
  }

  public void testScoreCosine() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 1; j <= 5; j++) {
          Document doc = new Document();
          doc.add(getKnnVectorField("field", new float[] {j, j * j}, COSINE));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        assertEquals(1, reader.leaves().size());
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
        Query rewritten = query.rewrite(searcher);
        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.leaves().get(0));

        // prior to advancing, score is undefined
        assertEquals(-1, scorer.docID());
        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

        // test getMaxScore
        assertEquals(0, scorer.getMaxScore(-1), 0);
        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
         * normalized by (1 + x) /2.
         */
        float maxAtZero =
            (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);

        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
         * normalized by (1 + x) /2
         */
        float expected =
            (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
        assertEquals(expected, scorer.getMaxScore(2), 0);
        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);

        DocIdSetIterator it = scorer.iterator();
        assertEquals(3, it.cost());
        assertEquals(0, it.nextDoc());
        // doc 0 has (1, 1)
        assertEquals(maxAtZero, scorer.score(), 0.0001);
        assertEquals(1, it.advance(1));
        assertEquals(expected, scorer.score(), 0);
        assertEquals(2, it.nextDoc());
        // since topK was 3
        assertEquals(NO_MORE_DOCS, it.advance(4));
        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
      }
    }
  }

  public void testExplain() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(getKnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
        Explanation matched = searcher.explain(query, 2);
        assertTrue(matched.isMatch());
        assertEquals(1 / 2f, matched.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("within top 3", matched.getDescription());

        Explanation nomatch = searcher.explain(query, 4);
        assertFalse(nomatch.isMatch());
        assertEquals(0f, nomatch.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("not in top 3", nomatch.getDescription());
      }
    }
  }

  public void testExplainMultipleSegments() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(getKnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
          w.commit();
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
        Explanation matched = searcher.explain(query, 2);
        assertTrue(matched.isMatch());
        assertEquals(1 / 2f, matched.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("within top 3", matched.getDescription());

        Explanation nomatch = searcher.explain(query, 4);
        assertFalse(nomatch.isMatch());
        assertEquals(0f, nomatch.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("not in top 3", nomatch.getDescription());
      }
    }
  }

  /** Test that when vectors are abnormally distributed among segments, we still find the top K */
  public void testSkewedIndex() throws IOException {
    /* We have to choose the numbers carefully here so that some segment has more than the expected
     * number of top K documents, but no more than K documents in total (otherwise we might occasionally
     * randomly fail to find one).
     */
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        int r = 0;
        for (int i = 0; i < 5; i++) {
          for (int j = 0; j < 5; j++) {
            Document doc = new Document();
            doc.add(getKnnVectorField("field", new float[] {r, r}));
            doc.add(new StringField("id", "id" + r, Field.Store.YES));
            w.addDocument(doc);
            ++r;
          }
          w.flush();
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        TopDocs results = searcher.search(getKnnVectorQuery("field", new float[] {0, 0}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertIdMatches(reader, "id0", results.scoreDocs[0]);
        assertIdMatches(reader, "id7", results.scoreDocs[7]);

        // test some results in the middle of the sequence - also tests docid tiebreaking
        results = searcher.search(getKnnVectorQuery("field", new float[] {10, 10}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertIdMatches(reader, "id10", results.scoreDocs[0]);
        assertIdMatches(reader, "id6", results.scoreDocs[7]);
      }
    }
  }

  /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
  public void testRandom() throws IOException {
    int numDocs = atLeast(100);
    int dimension = atLeast(5);
    int numIters = atLeast(10);
    boolean everyDocHasAVector = random().nextBoolean();
    try (Directory d = newDirectory()) {
      RandomIndexWriter w = new RandomIndexWriter(random(), d);
      for (int i = 0; i < numDocs; i++) {
        Document doc = new Document();
        if (everyDocHasAVector || random().nextInt(10) != 2) {
          doc.add(getKnnVectorField("field", randomVector(dimension)));
        }
        w.addDocument(doc);
      }
      w.close();
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        for (int i = 0; i < numIters; i++) {
          int k = random().nextInt(80) + 1;
          AbstractKnnVectorQuery query = getKnnVectorQuery("field", randomVector(dimension), k);
          int n = random().nextInt(100) + 1;
          TopDocs results = searcher.search(query, n);
          int expected = Math.min(Math.min(n, k), reader.numDocs());
          // we may get fewer results than requested if there are deletions, but this test doesn't
          // test that
          assert reader.hasDeletions() == false;
          assertEquals(expected, results.scoreDocs.length);
          assertTrue(results.totalHits.value >= results.scoreDocs.length);
          // verify the results are in descending score order
          float last = Float.MAX_VALUE;
          for (ScoreDoc scoreDoc : results.scoreDocs) {
            assertTrue(scoreDoc.score <= last);
            last = scoreDoc.score;
          }
        }
      }
    }
  }

  /** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
  public void testRandomWithFilter() throws IOException {
    int numDocs = 1000;
    int dimension = atLeast(5);
    int numIters = atLeast(10);
    try (Directory d = newDirectory()) {
      // Always use the default kNN format to have predictable behavior around when it hits
      // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
      // format implementation.
      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
      RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
      for (int i = 0; i < numDocs; i++) {
        Document doc = new Document();
        doc.add(getKnnVectorField("field", randomVector(dimension)));
        doc.add(new NumericDocValuesField("tag", i));
        doc.add(new IntPoint("tag", i));
        w.addDocument(doc);
      }
      w.forceMerge(1);
      w.close();

      try (DirectoryReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        for (int i = 0; i < numIters; i++) {
          int lower = random().nextInt(500);

          // Test a filter with cost less than k and check we use exact search
          Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
          TopDocs results =
              searcher.search(
                  getKnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
          assertEquals(9, results.totalHits.value);
          assertEquals(results.totalHits.value, results.scoreDocs.length);
          expectThrows(
              UnsupportedOperationException.class,
              () ->
                  searcher.search(
                      getThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
                      numDocs));

          // Test a restrictive filter and check we use exact search
          Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
          results =
              searcher.search(
                  getKnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
          assertEquals(5, results.totalHits.value);
          assertEquals(results.totalHits.value, results.scoreDocs.length);
          expectThrows(
              UnsupportedOperationException.class,
              () ->
                  searcher.search(
                      getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
                      numDocs));

          // Test an unrestrictive filter and check we use approximate search
          Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs);
          results =
              searcher.search(
                  getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
                  numDocs,
                  new Sort(new SortField("tag", SortField.Type.INT)));
          assertEquals(5, results.totalHits.value);
          assertEquals(results.totalHits.value, results.scoreDocs.length);

          for (ScoreDoc scoreDoc : results.scoreDocs) {
            FieldDoc fieldDoc = (FieldDoc) scoreDoc;
            assertEquals(1, fieldDoc.fields.length);

            int tag = (int) fieldDoc.fields[0];
            assertTrue(lower <= tag && tag <= numDocs);
          }

          // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
          Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
          expectThrows(
              UnsupportedOperationException.class,
              () ->
                  searcher.search(
                      getThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
                      numDocs));
        }
      }
    }
  }

  /** Tests filtering when all vectors have the same score. */
  @AwaitsFix(bugUrl = "https://github.com/apache/lucene/issues/11787")
  public void testFilterWithSameScore() throws IOException {
    int numDocs = 100;
    int dimension = atLeast(5);
    try (Directory d = newDirectory()) {
      // Always use the default kNN format to have predictable behavior around when it hits
      // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
      // format implementation.
      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
      IndexWriter w = new IndexWriter(d, iwc);
      float[] vector = randomVector(dimension);
      for (int i = 0; i < numDocs; i++) {
        Document doc = new Document();
        doc.add(getKnnVectorField("field", vector));
        doc.add(new IntPoint("tag", i));
        w.addDocument(doc);
      }
      w.forceMerge(1);
      w.close();

      try (DirectoryReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        int lower = random().nextInt(50);
        int size = 5;

        // Test a restrictive filter, which usually performs exact search
        Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
        TopDocs results =
            searcher.search(
                getKnnVectorQuery("field", randomVector(dimension), size, filter1), size);
        assertEquals(size, results.scoreDocs.length);

        // Test an unrestrictive filter, which usually performs approximate search
        Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
        results =
            searcher.search(
                getKnnVectorQuery("field", randomVector(dimension), size, filter2), size);
        assertEquals(size, results.scoreDocs.length);
      }
    }
  }

  public void testDeletes() throws IOException {
    try (Directory dir = newDirectory();
        IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
      final int numDocs = atLeast(100);
      final int dim = 30;
      for (int i = 0; i < numDocs; ++i) {
        Document d = new Document();
        d.add(new StringField("index", String.valueOf(i), Field.Store.YES));
        if (frequently()) {
          d.add(getKnnVectorField("vector", randomVector(dim)));
        }
        w.addDocument(d);
      }
      w.commit();

      // Delete some documents at random, both those with and without vectors
      Set<Term> toDelete = new HashSet<>();
      for (int i = 0; i < 25; i++) {
        int index = random().nextInt(numDocs);
        toDelete.add(new Term("index", String.valueOf(index)));
      }
      w.deleteDocuments(toDelete.toArray(new Term[0]));
      w.commit();

      int hits = 50;
      try (IndexReader reader = DirectoryReader.open(dir)) {
        Set<String> allIds = new HashSet<>();
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), hits);
        TopDocs topDocs = searcher.search(query, numDocs);
        StoredFields storedFields = reader.storedFields();
        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
          Document doc = storedFields.document(scoreDoc.doc, Set.of("index"));
          String index = doc.get("index");
          assertFalse(
              "search returned a deleted document: " + index,
              toDelete.contains(new Term("index", index)));
          allIds.add(index);
        }
        assertEquals("search missed some documents", hits, allIds.size());
      }
    }
  }

  public void testAllDeletes() throws IOException {
    try (Directory dir = newDirectory();
        IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
      final int numDocs = atLeast(100);
      final int dim = 30;
      for (int i = 0; i < numDocs; ++i) {
        Document d = new Document();
        d.add(getKnnVectorField("vector", randomVector(dim)));
        w.addDocument(d);
      }
      w.commit();

      w.deleteDocuments(new MatchAllDocsQuery());
      w.commit();

      try (IndexReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
        TopDocs topDocs = searcher.search(query, numDocs);
        assertEquals(0, topDocs.scoreDocs.length);
      }
    }
  }

  /**
   * Check that the query behaves reasonably when using a custom filter reader where there are no
   * live docs.
   */
  public void testNoLiveDocsReader() throws IOException {
    IndexWriterConfig iwc = newIndexWriterConfig();
    try (Directory dir = newDirectory();
        IndexWriter w = new IndexWriter(dir, iwc)) {
      final int numDocs = 10;
      final int dim = 30;
      for (int i = 0; i < numDocs; ++i) {
        Document d = new Document();
        d.add(new StringField("index", String.valueOf(i), Field.Store.NO));
        d.add(getKnnVectorField("vector", randomVector(dim)));
        w.addDocument(d);
      }
      w.commit();

      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        DirectoryReader wrappedReader = new NoLiveDocsDirectoryReader(reader);
        IndexSearcher searcher = new IndexSearcher(wrappedReader);
        AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
        TopDocs topDocs = searcher.search(query, numDocs);
        assertEquals(0, topDocs.scoreDocs.length);
      }
    }
  }

  /**
   * Test that AbstractKnnVectorQuery optimizes the case where the filter query is backed by {@link
   * BitSetIterator}.
   */
  public void testBitSetQuery() throws IOException {
    IndexWriterConfig iwc = newIndexWriterConfig();
    try (Directory dir = newDirectory();
        IndexWriter w = new IndexWriter(dir, iwc)) {
      final int numDocs = 100;
      final int dim = 30;
      for (int i = 0; i < numDocs; ++i) {
        Document d = new Document();
        d.add(getKnnVectorField("vector", randomVector(dim)));
        w.addDocument(d);
      }
      w.commit();

      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);

        Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
        expectThrows(
            UnsupportedOperationException.class,
            () ->
                searcher.search(
                    getKnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
      }
    }
  }

  /** Creates a new directory and adds documents with the given vectors as kNN vector fields */
  Directory getIndexStore(String field, float[]... contents) throws IOException {
    Directory indexStore = newDirectory();
    RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
    for (int i = 0; i < contents.length; ++i) {
      Document doc = new Document();
      doc.add(getKnnVectorField(field, contents[i]));
      doc.add(new StringField("id", "id" + i, Field.Store.YES));
      writer.addDocument(doc);
    }
    // Add some documents without a vector
    for (int i = 0; i < 5; i++) {
      Document doc = new Document();
      doc.add(new StringField("other", "value", Field.Store.NO));
      writer.addDocument(doc);
    }
    writer.close();
    return indexStore;
  }

  /**
   * Creates a new directory and adds documents with the given vectors as kNN vector fields,
   * preserving the order of the added documents.
   */
  private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
    Directory indexStore = newDirectory();
    try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
      for (int i = 0; i < contents.length; ++i) {
        Document doc = new Document();
        doc.add(getKnnVectorField(field, contents[i]));
        doc.add(new StringField("id", "id" + i, Field.Store.YES));
        writer.addDocument(doc);
      }
      // Add some documents without a vector
      for (int i = 0; i < 5; i++) {
        Document doc = new Document();
        doc.add(new StringField("other", "value", Field.Store.NO));
        writer.addDocument(doc);
      }
    }
    return indexStore;
  }

  private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
      throws IOException {
    ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
    assertEquals(expectedMatches, result.length);
  }

  void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc)
      throws IOException {
    String actualId = reader.storedFields().document(scoreDoc.doc).get("id");
    assertEquals(expectedId, actualId);
  }

  /**
   * A version of {@link AbstractKnnVectorQuery} that throws an error when an exact search is run.
   * This allows us to check what search strategy is being used.
   */
  private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {

    private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {
      super(
          in,
          new SubReaderWrapper() {
            @Override
            public LeafReader wrap(LeafReader reader) {
              return new NoLiveDocsLeafReader(reader);
            }
          });
    }

    @Override
    protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
      return new NoLiveDocsDirectoryReader(in);
    }

    @Override
    public CacheHelper getReaderCacheHelper() {
      return in.getReaderCacheHelper();
    }
  }

  private static class NoLiveDocsLeafReader extends FilterLeafReader {
    private NoLiveDocsLeafReader(LeafReader in) {
      super(in);
    }

    @Override
    public int numDocs() {
      return 0;
    }

    @Override
    public Bits getLiveDocs() {
      return new Bits.MatchNoBits(in.maxDoc());
    }

    @Override
    public CacheHelper getReaderCacheHelper() {
      return in.getReaderCacheHelper();
    }

    @Override
    public CacheHelper getCoreCacheHelper() {
      return in.getCoreCacheHelper();
    }
  }

  static class ThrowingBitSetQuery extends Query {

    private final FixedBitSet docs;

    ThrowingBitSetQuery(FixedBitSet docs) {
      this.docs = docs;
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
        throws IOException {
      return new ConstantScoreWeight(this, boost) {
        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException {
          BitSetIterator bitSetIterator =
              new BitSetIterator(docs, docs.approximateCardinality()) {
                @Override
                public BitSet getBitSet() {
                  throw new UnsupportedOperationException("reusing BitSet is not supported");
                }
              };
          return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
          return false;
        }
      };
    }

    @Override
    public void visit(QueryVisitor visitor) {}

    @Override
    public String toString(String field) {
      return "throwingBitSetQuery";
    }

    @Override
    public boolean equals(Object other) {
      return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
    }

    @Override
    public int hashCode() {
      return 31 * classHash() + docs.hashCode();
    }
  }
}
@@ -0,0 +1,97 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
  @Override
  AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
    return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
  }

  @Override
  AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
    return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
  }

  @Override
  float[] randomVector(int dim) {
    BytesRef bytesRef = TestVectorUtil.randomVectorBytes(dim);
    float[] v = new float[bytesRef.length];
    int vi = 0;
    for (int i = bytesRef.offset; i < v.length; i++) {
      v[vi++] = bytesRef.bytes[i];
    }
    return v;
  }

  @Override
  Field getKnnVectorField(
      String name, float[] vector, VectorSimilarityFunction similarityFunction) {
    return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
  }

  @Override
  Field getKnnVectorField(String name, float[] vector) {
    return new KnnVectorField(
        name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
  }

  private static byte[] floatToBytes(float[] query) {
    byte[] bytes = new byte[query.length];
    for (int i = 0; i < query.length; i++) {
      assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
          : "float value cannot be converted to byte; provided: " + query[i];
      bytes[i] = (byte) query[i];
    }
    return bytes;
  }

  public void testToString() {
    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
    assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
  }

  @Override
  VectorEncoding vectorEncoding() {
    return VectorEncoding.BYTE;
  }

  private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {

    public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
      super(field, target, k, filter);
    }

    @Override
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
      throw new UnsupportedOperationException("exact search is not supported");
    }

    @Override
    public String toString(String field) {
      return null;
    }
  }
}
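As an end-to-end illustration of what the new query class means for applications, the sketch below indexes one BYTE-encoded vector and searches it. The field name and vector values are invented for the example; the field is built with the KnnVectorField BytesRef constructor exactly as the test above does, and, per the CHANGES entry, such a field can no longer be queried with KnnVectorQuery.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

public class KnnByteVectorQueryDemo {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory()) {
      try (IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) {
        Document doc = new Document();
        // Index a BYTE-encoded vector field, as the tests above do.
        doc.add(
            new KnnVectorField(
                "v", new BytesRef(new byte[] {1, 2, 3}), VectorSimilarityFunction.EUCLIDEAN));
        w.addDocument(doc);
      }
      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        // BYTE-encoded fields must now be queried with KnnByteVectorQuery.
        TopDocs top =
            searcher.search(new KnnByteVectorQuery("v", new byte[] {1, 2, 2}, 1, null), 1);
        System.out.println(top.totalHits);
      }
    }
  }
}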
File diff suppressed because it is too large
@@ -172,6 +172,17 @@ public class TestVectorUtil extends LuceneTestCase {
     return v;
   }
 
+  public static BytesRef randomVectorBytes(int dim) {
+    BytesRef v = TestUtil.randomBinaryTerm(random(), dim);
+    // clip at -127 to avoid overflow
+    for (int i = v.offset; i < v.offset + v.length; i++) {
+      if (v.bytes[i] == -128) {
+        v.bytes[i] = -127;
+      }
+    }
+    return v;
+  }
+
   public void testBasicDotProductBytes() {
     BytesRef a = new BytesRef(new byte[] {1, 2, 3});
     BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
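The clip to -127 sidesteps the usual two's-complement edge case: Byte.MIN_VALUE has no positive counterpart at byte width, so any byte-width negation or absolute value would overflow. A tiny standalone illustration of that Java behavior (my example, not from the patch):

public class ByteOverflowDemo {
  public static void main(String[] args) {
    byte b = -128; // Byte.MIN_VALUE
    byte negated = (byte) -b; // 128 does not fit in a byte ...
    System.out.println(negated); // ... so this prints -128 again
    System.out.println(Math.abs((int) b)); // widening to int first gives 128
  }
}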
@@ -282,15 +282,26 @@ public class TestHnswGraph extends LuceneTestCase {
     OnHeapHnswGraph hnsw = builder.build(vectors.copy());
     // run some searches
     NeighborQueue nn =
-        HnswGraphSearcher.search(
-            getTargetVector(),
-            10,
-            vectors.copy(),
-            vectorEncoding,
-            similarityFunction,
-            hnsw,
-            null,
-            Integer.MAX_VALUE);
+        switch (vectorEncoding) {
+          case BYTE -> HnswGraphSearcher.search(
+              getTargetByteVector(),
+              10,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              null,
+              Integer.MAX_VALUE);
+          case FLOAT32 -> HnswGraphSearcher.search(
+              getTargetVector(),
+              10,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              null,
+              Integer.MAX_VALUE);
+        };
 
     int[] nodes = nn.nodes();
     assertEquals("Number of found results is not equal to [10].", 10, nodes.length);

@@ -324,15 +335,26 @@ public class TestHnswGraph extends LuceneTestCase {
     // the first 10 docs must not be deleted to ensure the expected recall
     Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
     NeighborQueue nn =
-        HnswGraphSearcher.search(
-            getTargetVector(),
-            10,
-            vectors.copy(),
-            vectorEncoding,
-            similarityFunction,
-            hnsw,
-            acceptOrds,
-            Integer.MAX_VALUE);
+        switch (vectorEncoding) {
+          case BYTE -> HnswGraphSearcher.search(
+              getTargetByteVector(),
+              10,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+          case FLOAT32 -> HnswGraphSearcher.search(
+              getTargetVector(),
+              10,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+        };
     int[] nodes = nn.nodes();
     assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
     int sum = 0;

@@ -363,15 +385,27 @@ public class TestHnswGraph extends LuceneTestCase {
     // Check the search finds all accepted vectors
     int numAccepted = acceptOrds.cardinality();
     NeighborQueue nn =
-        HnswGraphSearcher.search(
-            getTargetVector(),
-            numAccepted,
-            vectors.copy(),
-            vectorEncoding,
-            similarityFunction,
-            hnsw,
-            acceptOrds,
-            Integer.MAX_VALUE);
+        switch (vectorEncoding) {
+          case FLOAT32 -> HnswGraphSearcher.search(
+              getTargetVector(),
+              numAccepted,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+          case BYTE -> HnswGraphSearcher.search(
+              getTargetByteVector(),
+              numAccepted,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+        };
 
     int[] nodes = nn.nodes();
     assertEquals(numAccepted, nodes.length);
     for (int node : nodes) {

@@ -383,6 +417,10 @@ public class TestHnswGraph extends LuceneTestCase {
     return new float[] {1, 0};
   }
 
+  private BytesRef getTargetByteVector() {
+    return new BytesRef(new byte[] {1, 0});
+  }
+
   public void testSearchWithSkewedAcceptOrds() throws IOException {
     int nDoc = 1000;
     similarityFunction = VectorSimilarityFunction.EUCLIDEAN;

@@ -432,15 +470,27 @@ public class TestHnswGraph extends LuceneTestCase {
     int topK = 50;
     int visitedLimit = topK + random().nextInt(5);
     NeighborQueue nn =
-        HnswGraphSearcher.search(
-            getTargetVector(),
-            topK,
-            vectors.copy(),
-            vectorEncoding,
-            similarityFunction,
-            hnsw,
-            createRandomAcceptOrds(0, vectors.size),
-            visitedLimit);
+        switch (vectorEncoding) {
+          case FLOAT32 -> HnswGraphSearcher.search(
+              getTargetVector(),
+              topK,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              createRandomAcceptOrds(0, vectors.size),
+              visitedLimit);
+          case BYTE -> HnswGraphSearcher.search(
+              getTargetByteVector(),
+              topK,
+              vectors.copy(),
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              createRandomAcceptOrds(0, vectors.size),
+              visitedLimit);
+        };
 
     assertTrue(nn.incomplete());
     // The visited count shouldn't exceed the limit
     assertTrue(nn.visitedCount() <= visitedLimit);

@@ -664,15 +714,27 @@ public class TestHnswGraph extends LuceneTestCase {
       query = randomVector(random(), dim);
     }
     actual =
-        HnswGraphSearcher.search(
-            query,
-            100,
-            vectors,
-            vectorEncoding,
-            similarityFunction,
-            hnsw,
-            acceptOrds,
-            Integer.MAX_VALUE);
+        switch (vectorEncoding) {
+          case BYTE -> HnswGraphSearcher.search(
+              bQuery,
+              100,
+              vectors,
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+          case FLOAT32 -> HnswGraphSearcher.search(
+              query,
+              100,
+              vectors,
+              vectorEncoding,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+        };
 
     while (actual.size() > topK) {
       actual.pop();
     }
@@ -41,6 +41,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.Version;
 
 /**

@@ -170,6 +171,12 @@ public class TermVectorLeafReader extends LeafReader {
     return null;
   }
 
+  @Override
+  public TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+    return null;
+  }
+
   @Override
   public void checkIntegrity() throws IOException {}
@@ -1401,6 +1401,12 @@ public class MemoryIndex {
       return null;
     }
 
+    @Override
+    public TopDocs searchNearestVectors(
+        String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+      return null;
+    }
+
     @Override
     public void checkIntegrity() throws IOException {
       // no-op
@@ -28,10 +28,12 @@ import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
 
 /** Wraps the default KnnVectorsFormat and provides additional assertions. */
 public class AssertingKnnVectorsFormat extends KnnVectorsFormat {

@@ -124,7 +126,22 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
     public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
         throws IOException {
       FieldInfo fi = fis.fieldInfo(field);
-      assert fi != null && fi.getVectorDimension() > 0;
+      assert fi != null
+          && fi.getVectorDimension() > 0
+          && fi.getVectorEncoding() == VectorEncoding.FLOAT32;
       TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
       assert hits != null;
       assert hits.scoreDocs.length <= k;
       return hits;
     }
 
+    @Override
+    public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+        throws IOException {
+      FieldInfo fi = fis.fieldInfo(field);
+      assert fi != null
+          && fi.getVectorDimension() > 0
+          && fi.getVectorEncoding() == VectorEncoding.BYTE;
+      TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
+      assert hits != null;
+      assert hits.scoreDocs.length <= k;
@@ -41,6 +41,7 @@ import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
 
 /**
  * This is a hack to make index sorting fast, with a {@link LeafReader} that always returns merge

@@ -227,6 +228,12 @@ class MergeReaderWrapper extends LeafReader {
     return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
   }
 
+  @Override
+  public TopDocs searchNearestVectors(
+      String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+    return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+  }
+
   @Override
   public int numDocs() {
     return in.numDocs();
@@ -54,6 +54,7 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.Version;
 import org.junit.Assert;
 

@@ -234,6 +235,12 @@ public class QueryUtils {
       return null;
     }
 
+    @Override
+    public TopDocs searchNearestVectors(
+        String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+      return null;
+    }
+
     @Override
     public FieldInfos getFieldInfos() {
       return FieldInfos.EMPTY;