LUCENE-10577: Remove LeafReader#searchNearestVectorsExhaustively (#11756)

This PR removes the recently added function on LeafReader to exhaustively search
through vectors, plus the helper function KnnVectorsReader#searchExhaustively.
Instead it performs the exact search within KnnVectorQuery, using a new helper
class called VectorScorer.
This commit is contained in:
Julie Tibshirani 2022-09-08 12:15:02 -07:00 committed by GitHub
parent f4146a44e9
commit 09a13aeaf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 267 additions and 334 deletions

View File

@ -18,7 +18,6 @@
package org.apache.lucene.backward_codecs.lucene90; package org.apache.lucene.backward_codecs.lucene90;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -36,7 +35,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -278,21 +276,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException { private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice = IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@ -18,7 +18,6 @@
package org.apache.lucene.backward_codecs.lucene91; package org.apache.lucene.backward_codecs.lucene91;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -37,7 +36,6 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -268,21 +266,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException { private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice = IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@ -18,7 +18,6 @@
package org.apache.lucene.backward_codecs.lucene92; package org.apache.lucene.backward_codecs.lucene92;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -34,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -262,21 +260,6 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
/** Get knn graph values; used for testing */ /** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException { public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field); FieldInfo info = fieldInfos.fieldInfo(field);

View File

@ -183,14 +183,6 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs); return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
return exhaustiveSearch(getVectorValues(field), acceptDocs, vectorSimilarity, target, k);
}
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
IndexInput clone = dataIn.clone(); IndexInput clone = dataIn.clone();

View File

@ -21,7 +21,6 @@ import java.io.IOException;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -105,12 +104,6 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
return TopDocsCollector.EMPTY_TOPDOCS; return TopDocsCollector.EMPTY_TOPDOCS;
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
@Override @Override
public void close() {} public void close() {}

View File

@ -20,16 +20,12 @@ package org.apache.lucene.codecs;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Reads vectors from an index. */ /** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable { public abstract class KnnVectorsReader implements Closeable, Accountable {
@ -84,34 +80,6 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public abstract TopDocs search( public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
* this requires an exhaustive scan of the entire index. It is intended to be used when the number
* of potential matches is limited.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
* contains the number of documents visited during the search. If the search stopped early because
* it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
/** /**
* Returns an instance optimized for merging. This instance may only be consumed in the thread * Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}. * that called {@link #getMergeInstance()}.
@ -121,67 +89,4 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public KnnVectorsReader getMergeInstance() { public KnnVectorsReader getMergeInstance() {
return this; return this;
} }
/** {@link #searchExhaustively} */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
float[] target,
int k)
throws IOException {
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.vectorValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}
/** {@link #searchExhaustively} */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
BytesRef target,
int k)
throws IOException {
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.binaryValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}
private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
} }

View File

@ -18,8 +18,6 @@
package org.apache.lucene.codecs.lucene94; package org.apache.lucene.codecs.lucene94;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -35,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -284,25 +281,6 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return switch (fieldEntry.vectorEncoding) {
case BYTE -> exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
};
}
/** Get knn graph values; used for testing */ /** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException { public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field); FieldInfo info = fieldInfos.fieldInfo(field);

View File

@ -33,7 +33,6 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -268,17 +267,6 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
} }
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
}
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(fields.values()); IOUtils.close(fields.values());

View File

@ -88,12 +88,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
}; };
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc); writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@ -128,12 +122,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState); return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);

View File

@ -25,7 +25,6 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -236,19 +235,6 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit); return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public final TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// Field does not exist or does not index vectors
return null;
}
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
}
@Override @Override
protected void doClose() throws IOException {} protected void doClose() throws IOException {}

View File

@ -18,7 +18,6 @@
package org.apache.lucene.index; package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -59,12 +58,6 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public final void checkIntegrity() throws IOException { public final void checkIntegrity() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -18,7 +18,6 @@ package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import java.util.Iterator; import java.util.Iterator;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -358,12 +357,6 @@ public abstract class FilterLeafReader extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override @Override
public Fields getTermVectors(int docID) throws IOException { public Fields getTermVectors(int docID) throws IOException {
ensureOpen(); ensureOpen();

View File

@ -17,7 +17,6 @@
package org.apache.lucene.index; package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -236,30 +235,6 @@ public abstract class LeafReader extends IndexReader {
public abstract TopDocs searchNearestVectors( public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
* This typically requires an exhaustive scan of all candidate documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
* TotalHits} contains the number of documents visited during the search.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
* {@code null} if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
/** /**
* Get the {@link FieldInfos} describing all fields in this reader. * Get the {@link FieldInfos} describing all fields in this reader.
* *

View File

@ -26,7 +26,6 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.SortedMap; import java.util.SortedMap;
import java.util.TreeMap; import java.util.TreeMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -404,16 +403,6 @@ public class ParallelLeafReader extends LeafReader {
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit); : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null
? null
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
}
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
ensureOpen(); ensureOpen();

View File

@ -27,7 +27,6 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -173,12 +172,6 @@ public final class SlowCodecReaderWrapper {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override @Override
public void checkIntegrity() { public void checkIntegrity() {
// We already checkIntegrity the entire reader up front // We already checkIntegrity the entire reader up front

View File

@ -31,7 +31,6 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
@ -390,12 +389,6 @@ public final class SortingCodecReader extends FilterCodecReader {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
delegate.close(); delegate.close();

View File

@ -24,6 +24,7 @@ import java.util.Comparator;
import java.util.Objects; import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
@ -175,9 +176,42 @@ public class KnnVectorQuery extends Query {
} }
// We allow this to be overridden so that tests can check what search strategy is used // We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs) protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException { throws IOException {
return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs); FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
} }
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {

View File

@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
/**
* Computes the similarity score between a given query vector and different document vectors. This
* is primarily used by {@link org.apache.lucene.search.KnnVectorQuery} to run an exact, exhaustive
* search over the vectors.
*/
abstract class VectorScorer {
protected final VectorValues values;
protected final VectorSimilarityFunction similarity;
/**
* Create a new vector scorer instance.
*
* @param context the reader context
* @param fi the FieldInfo for the field containing document vectors
* @param query the query vector to compute the similarity for
*/
static VectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
throws IOException {
VectorValues values = context.reader().getVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return switch (fi.getVectorEncoding()) {
case BYTE -> new ByteVectorScorer(values, query, similarity);
case FLOAT32 -> new FloatVectorScorer(values, query, similarity);
};
}
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
this.values = values;
this.similarity = similarity;
}
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
}
/** Compute the similarity score for the current document. */
abstract float score() throws IOException;
private static class ByteVectorScorer extends VectorScorer {
private final BytesRef query;
protected ByteVectorScorer(
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
super(values, similarity);
this.query = VectorUtil.toBytesRef(query);
}
@Override
public float score() throws IOException {
return similarity.compare(query, values.binaryValue());
}
}
private static class FloatVectorScorer extends VectorScorer {
private final float[] query;
protected FloatVectorScorer(
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
super(values, similarity);
this.query = query;
}
@Override
public float score() throws IOException {
return similarity.compare(query, values.vectorValue());
}
}
}

View File

@ -25,7 +25,6 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
@ -118,12 +117,6 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
protected void doClose() {} protected void doClose() {}

View File

@ -631,6 +631,48 @@ public class TestKnnVectorQuery extends LuceneTestCase {
} }
} }
/** Tests filtering when all vectors have the same score. */
public void testFilterWithSameScore() throws IOException {
int numDocs = 100;
int dimension = atLeast(5);
try (Directory d = newDirectory()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
IndexWriter w = new IndexWriter(d, iwc);
float[] vector = randomVector(dimension);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
doc.add(new KnnVectorField("field", vector));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.forceMerge(1);
w.close();
try (DirectoryReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
int lower = random().nextInt(50);
int size = 5;
// Test a restrictive filter, which usually performs exact search
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
TopDocs results =
searcher.search(
new KnnVectorQuery("field", randomVector(dimension), size, filter1), size);
assertEquals(size, results.scoreDocs.length);
// Test an unrestrictive filter, which usually performs approximate search
Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
results =
searcher.search(
new KnnVectorQuery("field", randomVector(dimension), size, filter2), size);
assertEquals(size, results.scoreDocs.length);
}
}
}
public void testDeletes() throws IOException { public void testDeletes() throws IOException {
try (Directory dir = newDirectory(); try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
public class TestVectorScorer extends LuceneTestCase {
public void testFindAll() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
assert reader.leaves().size() == 1;
LeafReaderContext context = reader.leaves().get(0);
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
int numDocs = 0;
for (int i = 0; i < reader.maxDoc(); i++) {
if (vectorScorer.advanceExact(i)) {
numDocs++;
}
}
assertEquals(3, numDocs);
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
VectorEncoding encoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {
BytesRef v = new BytesRef(new byte[contents[i].length]);
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
}
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
} else {
doc.add(new KnnVectorField(field, contents[i]));
}
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {
Document doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
writer.addDocument(doc);
}
writer.forceMerge(1);
writer.close();
return indexStore;
}
}

View File

@ -37,7 +37,6 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version; import org.apache.lucene.util.Version;
@ -169,12 +168,6 @@ public class TermVectorLeafReader extends LeafReader {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
public void checkIntegrity() throws IOException {} public void checkIntegrity() throws IOException {}

View File

@ -37,7 +37,6 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.*; import org.apache.lucene.index.*;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable; import org.apache.lucene.search.Scorable;
@ -1374,12 +1373,6 @@ public class MemoryIndex {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
// no-op // no-op

View File

@ -29,7 +29,6 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -132,18 +131,6 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return hits; return hits;
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
assert acceptDocs != null;
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
delegate.close(); delegate.close();

View File

@ -40,7 +40,6 @@ import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms; import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -229,12 +228,6 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override @Override
public int numDocs() { public int numDocs() {
return in.numDocs(); return in.numDocs();

View File

@ -233,12 +233,6 @@ public class QueryUtils {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
public FieldInfos getFieldInfos() { public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY; return FieldInfos.EMPTY;