mirror of https://github.com/apache/lucene.git
LUCENE-10577: Remove LeafReader#searchNearestVectorsExhaustively (#11756)
This PR removes the recently added function on LeafReader to exhaustively search through vectors, plus the helper function KnnVectorsReader#searchExhaustively. Instead it performs the exact search within KnnVectorQuery, using a new helper class called VectorScorer.
This commit is contained in:
parent
f4146a44e9
commit
09a13aeaf2
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.backward_codecs.lucene90;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -36,7 +35,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -278,21 +276,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
// The field does not exist or does not index vectors
|
||||
return EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||
VectorValues vectorValues = getVectorValues(field);
|
||||
|
||||
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||
}
|
||||
|
||||
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -37,7 +36,6 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -268,21 +266,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
// The field does not exist or does not index vectors
|
||||
return EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||
VectorValues vectorValues = getVectorValues(field);
|
||||
|
||||
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||
}
|
||||
|
||||
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.backward_codecs.lucene92;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -34,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -262,21 +260,6 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
// The field does not exist or does not index vectors
|
||||
return EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||
VectorValues vectorValues = getVectorValues(field);
|
||||
|
||||
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||
}
|
||||
|
||||
/** Get knn graph values; used for testing */
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
|
|
|
@ -183,14 +183,6 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||
return exhaustiveSearch(getVectorValues(field), acceptDocs, vectorSimilarity, target, k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
IndexInput clone = dataIn.clone();
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.io.IOException;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -105,12 +104,6 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
|
|
|
@ -20,16 +20,12 @@ package org.apache.lucene.codecs;
|
|||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/** Reads vectors from an index. */
|
||||
public abstract class KnnVectorsReader implements Closeable, Accountable {
|
||||
|
@ -84,34 +80,6 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||
public abstract TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
|
||||
/**
|
||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||
* is derived from the vector similarity in a way that ensures scores are positive and that a
|
||||
* larger score corresponds to a higher ranking.
|
||||
*
|
||||
* <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
|
||||
* this requires an exhaustive scan of the entire index. It is intended to be used when the number
|
||||
* of potential matches is limited.
|
||||
*
|
||||
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
|
||||
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
|
||||
* contains the number of documents visited during the search. If the search stopped early because
|
||||
* it hit {@code visitedLimit}, it is indicated through the relation {@code
|
||||
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
|
||||
*
|
||||
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
|
||||
* FieldInfo}. The return value is never {@code null}.
|
||||
*
|
||||
* @param field the vector field to search
|
||||
* @param target the vector-valued query
|
||||
* @param k the number of docs to return
|
||||
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
|
||||
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
|
||||
*/
|
||||
public abstract TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns an instance optimized for merging. This instance may only be consumed in the thread
|
||||
* that called {@link #getMergeInstance()}.
|
||||
|
@ -121,67 +89,4 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||
public KnnVectorsReader getMergeInstance() {
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@link #searchExhaustively} */
|
||||
protected static TopDocs exhaustiveSearch(
|
||||
VectorValues vectorValues,
|
||||
DocIdSetIterator acceptDocs,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
float[] target,
|
||||
int k)
|
||||
throws IOException {
|
||||
HitQueue queue = new HitQueue(k, true);
|
||||
ScoreDoc topDoc = queue.top();
|
||||
int doc;
|
||||
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
int vectorDoc = vectorValues.advance(doc);
|
||||
assert vectorDoc == doc;
|
||||
float score = similarityFunction.compare(vectorValues.vectorValue(), target);
|
||||
if (score >= topDoc.score) {
|
||||
topDoc.score = score;
|
||||
topDoc.doc = doc;
|
||||
topDoc = queue.updateTop();
|
||||
}
|
||||
}
|
||||
return topDocsFromHitQueue(queue, acceptDocs.cost());
|
||||
}
|
||||
|
||||
/** {@link #searchExhaustively} */
|
||||
protected static TopDocs exhaustiveSearch(
|
||||
VectorValues vectorValues,
|
||||
DocIdSetIterator acceptDocs,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
BytesRef target,
|
||||
int k)
|
||||
throws IOException {
|
||||
HitQueue queue = new HitQueue(k, true);
|
||||
ScoreDoc topDoc = queue.top();
|
||||
int doc;
|
||||
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
int vectorDoc = vectorValues.advance(doc);
|
||||
assert vectorDoc == doc;
|
||||
float score = similarityFunction.compare(vectorValues.binaryValue(), target);
|
||||
if (score >= topDoc.score) {
|
||||
topDoc.score = score;
|
||||
topDoc.doc = doc;
|
||||
topDoc = queue.updateTop();
|
||||
}
|
||||
}
|
||||
return topDocsFromHitQueue(queue, acceptDocs.cost());
|
||||
}
|
||||
|
||||
private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
|
||||
// Remove any remaining sentinel values
|
||||
while (queue.size() > 0 && queue.top().score < 0) {
|
||||
queue.pop();
|
||||
}
|
||||
|
||||
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
|
||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
package org.apache.lucene.codecs.lucene94;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -35,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -284,25 +281,6 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
// The field does not exist or does not index vectors
|
||||
return EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||
VectorValues vectorValues = getVectorValues(field);
|
||||
|
||||
return switch (fieldEntry.vectorEncoding) {
|
||||
case BYTE -> exhaustiveSearch(
|
||||
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
|
||||
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||
};
|
||||
}
|
||||
|
||||
/** Get knn graph values; used for testing */
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
|
|
|
@ -33,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -268,17 +267,6 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||
} else {
|
||||
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(fields.values());
|
||||
|
|
|
@ -88,12 +88,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
||||
|
@ -128,12 +122,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -236,19 +235,6 @@ public abstract class CodecReader extends LeafReader {
|
|||
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
ensureOpen();
|
||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// Field does not exist or does not index vectors
|
||||
return null;
|
||||
}
|
||||
|
||||
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doClose() throws IOException {}
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -59,12 +58,6 @@ abstract class DocValuesLeafReader extends LeafReader {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void checkIntegrity() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.apache.lucene.index;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Iterator;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.AttributeSource;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -358,12 +357,6 @@ public abstract class FilterLeafReader extends LeafReader {
|
|||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Fields getTermVectors(int docID) throws IOException {
|
||||
ensureOpen();
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
|
@ -236,30 +235,6 @@ public abstract class LeafReader extends IndexReader {
|
|||
public abstract TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
|
||||
/**
|
||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||
* is derived from the vector similarity in a way that ensures scores are positive and that a
|
||||
* larger score corresponds to a higher ranking.
|
||||
*
|
||||
* <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
|
||||
* This typically requires an exhaustive scan of all candidate documents.
|
||||
*
|
||||
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
|
||||
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
|
||||
* TotalHits} contains the number of documents visited during the search.
|
||||
*
|
||||
* @param field the vector field to search
|
||||
* @param target the vector-valued query
|
||||
* @param k the number of docs to return
|
||||
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
|
||||
* {@code null} if they are all allowed to match.
|
||||
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
|
||||
|
||||
/**
|
||||
* Get the {@link FieldInfos} describing all fields in this reader.
|
||||
*
|
||||
|
|
|
@ -26,7 +26,6 @@ import java.util.Objects;
|
|||
import java.util.Set;
|
||||
import java.util.SortedMap;
|
||||
import java.util.TreeMap;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -404,16 +403,6 @@ public class ParallelLeafReader extends LeafReader {
|
|||
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
ensureOpen();
|
||||
LeafReader reader = fieldToReader.get(fieldName);
|
||||
return reader == null
|
||||
? null
|
||||
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
ensureOpen();
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -173,12 +172,6 @@ public final class SlowCodecReaderWrapper {
|
|||
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {
|
||||
// We already checkIntegrity the entire reader up front
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
|
@ -390,12 +389,6 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
delegate.close();
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Comparator;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
@ -175,9 +176,42 @@ public class KnnVectorQuery extends Query {
|
|||
}
|
||||
|
||||
// We allow this to be overridden so that tests can check what search strategy is used
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs)
|
||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
||||
throws IOException {
|
||||
return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// The field does not exist or does not index vectors
|
||||
return NO_RESULTS;
|
||||
}
|
||||
|
||||
VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
|
||||
HitQueue queue = new HitQueue(k, true);
|
||||
ScoreDoc topDoc = queue.top();
|
||||
int doc;
|
||||
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
boolean advanced = vectorScorer.advanceExact(doc);
|
||||
assert advanced;
|
||||
|
||||
float score = vectorScorer.score();
|
||||
if (score > topDoc.score) {
|
||||
topDoc.score = score;
|
||||
topDoc.doc = doc;
|
||||
topDoc = queue.updateTop();
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any remaining sentinel values
|
||||
while (queue.size() > 0 && queue.top().score < 0) {
|
||||
queue.pop();
|
||||
}
|
||||
|
||||
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
|
||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||
topScoreDocs[i] = queue.pop();
|
||||
}
|
||||
|
||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
||||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
* Computes the similarity score between a given query vector and different document vectors. This
|
||||
* is primarily used by {@link org.apache.lucene.search.KnnVectorQuery} to run an exact, exhaustive
|
||||
* search over the vectors.
|
||||
*/
|
||||
abstract class VectorScorer {
|
||||
protected final VectorValues values;
|
||||
protected final VectorSimilarityFunction similarity;
|
||||
|
||||
/**
|
||||
* Create a new vector scorer instance.
|
||||
*
|
||||
* @param context the reader context
|
||||
* @param fi the FieldInfo for the field containing document vectors
|
||||
* @param query the query vector to compute the similarity for
|
||||
*/
|
||||
static VectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
|
||||
throws IOException {
|
||||
VectorValues values = context.reader().getVectorValues(fi.name);
|
||||
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
||||
return switch (fi.getVectorEncoding()) {
|
||||
case BYTE -> new ByteVectorScorer(values, query, similarity);
|
||||
case FLOAT32 -> new FloatVectorScorer(values, query, similarity);
|
||||
};
|
||||
}
|
||||
|
||||
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
|
||||
this.values = values;
|
||||
this.similarity = similarity;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
/** Compute the similarity score for the current document. */
|
||||
abstract float score() throws IOException;
|
||||
|
||||
private static class ByteVectorScorer extends VectorScorer {
|
||||
private final BytesRef query;
|
||||
|
||||
protected ByteVectorScorer(
|
||||
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
||||
super(values, similarity);
|
||||
this.query = VectorUtil.toBytesRef(query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return similarity.compare(query, values.binaryValue());
|
||||
}
|
||||
}
|
||||
|
||||
private static class FloatVectorScorer extends VectorScorer {
|
||||
private final float[] query;
|
||||
|
||||
protected FloatVectorScorer(
|
||||
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
||||
super(values, similarity);
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return similarity.compare(query, values.vectorValue());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -25,7 +25,6 @@ import java.util.concurrent.LinkedBlockingQueue;
|
|||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
|
@ -118,12 +117,6 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doClose() {}
|
||||
|
||||
|
|
|
@ -631,6 +631,48 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/** Tests filtering when all vectors have the same score. */
|
||||
public void testFilterWithSameScore() throws IOException {
|
||||
int numDocs = 100;
|
||||
int dimension = atLeast(5);
|
||||
try (Directory d = newDirectory()) {
|
||||
// Always use the default kNN format to have predictable behavior around when it hits
|
||||
// visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
|
||||
// implementation.
|
||||
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
|
||||
IndexWriter w = new IndexWriter(d, iwc);
|
||||
float[] vector = randomVector(dimension);
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("field", vector));
|
||||
doc.add(new IntPoint("tag", i));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
w.forceMerge(1);
|
||||
w.close();
|
||||
|
||||
try (DirectoryReader reader = DirectoryReader.open(d)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
int lower = random().nextInt(50);
|
||||
int size = 5;
|
||||
|
||||
// Test a restrictive filter, which usually performs exact search
|
||||
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
|
||||
TopDocs results =
|
||||
searcher.search(
|
||||
new KnnVectorQuery("field", randomVector(dimension), size, filter1), size);
|
||||
assertEquals(size, results.scoreDocs.length);
|
||||
|
||||
// Test an unrestrictive filter, which usually performs approximate search
|
||||
Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
|
||||
results =
|
||||
searcher.search(
|
||||
new KnnVectorQuery("field", randomVector(dimension), size, filter2), size);
|
||||
assertEquals(size, results.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testDeletes() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
public class TestVectorScorer extends LuceneTestCase {
|
||||
|
||||
public void testFindAll() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
assert reader.leaves().size() == 1;
|
||||
LeafReaderContext context = reader.leaves().get(0);
|
||||
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
|
||||
VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
||||
|
||||
int numDocs = 0;
|
||||
for (int i = 0; i < reader.maxDoc(); i++) {
|
||||
if (vectorScorer.advanceExact(i)) {
|
||||
numDocs++;
|
||||
}
|
||||
}
|
||||
assertEquals(3, numDocs);
|
||||
}
|
||||
}
|
||||
|
||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||
VectorEncoding encoding =
|
||||
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||
for (int i = 0; i < contents.length; ++i) {
|
||||
Document doc = new Document();
|
||||
if (encoding == VectorEncoding.BYTE) {
|
||||
BytesRef v = new BytesRef(new byte[contents[i].length]);
|
||||
for (int j = 0; j < v.length; j++) {
|
||||
v.bytes[j] = (byte) contents[i][j];
|
||||
}
|
||||
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
|
||||
} else {
|
||||
doc.add(new KnnVectorField(field, contents[i]));
|
||||
}
|
||||
doc.add(new StringField("id", "id" + i, Field.Store.YES));
|
||||
writer.addDocument(doc);
|
||||
}
|
||||
// Add some documents without a vector
|
||||
for (int i = 0; i < 5; i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(new StringField("other", "value", Field.Store.NO));
|
||||
writer.addDocument(doc);
|
||||
}
|
||||
writer.forceMerge(1);
|
||||
writer.close();
|
||||
return indexStore;
|
||||
}
|
||||
}
|
|
@ -37,7 +37,6 @@ import org.apache.lucene.index.Terms;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.Version;
|
||||
|
@ -169,12 +168,6 @@ public class TermVectorLeafReader extends LeafReader {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {}
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ import org.apache.lucene.document.FieldType;
|
|||
import org.apache.lucene.index.*;
|
||||
import org.apache.lucene.search.Collector;
|
||||
import org.apache.lucene.search.CollectorManager;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.Scorable;
|
||||
|
@ -1374,12 +1373,6 @@ public class MemoryIndex {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
// no-op
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -132,18 +131,6 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
return hits;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
FieldInfo fi = fis.fieldInfo(field);
|
||||
assert fi != null && fi.getVectorDimension() > 0;
|
||||
assert acceptDocs != null;
|
||||
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
|
||||
assert hits != null;
|
||||
assert hits.scoreDocs.length <= k;
|
||||
return hits;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
delegate.close();
|
||||
|
|
|
@ -40,7 +40,6 @@ import org.apache.lucene.index.SortedSetDocValues;
|
|||
import org.apache.lucene.index.StoredFieldVisitor;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
|
@ -229,12 +228,6 @@ class MergeReaderWrapper extends LeafReader {
|
|||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numDocs() {
|
||||
return in.numDocs();
|
||||
|
|
|
@ -233,12 +233,6 @@ public class QueryUtils {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FieldInfos getFieldInfos() {
|
||||
return FieldInfos.EMPTY;
|
||||
|
|
Loading…
Reference in New Issue