[Abstract]Knn[Byte|Float]VectorQuery tweaks: reduce duplicate method calls (#13528)

* reduce LeafReaderContext.reader()[.maxDoc()] calls in AbstractKnnVectorQuery.getLeafResults

* reduce IndexReader.leaves() calls in AbstractKnnVectorQuery.findSegmentStarts

* reduce LeafReaderContext.reader() calls in Knn(Byte|Float)VectorQuery.approximateSearch
This commit is contained in:
Christine Poerschke 2024-07-01 17:31:02 +01:00 committed by GitHub
parent 3cd406e783
commit 0ad270d8b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 18 deletions

View File

@@ -28,6 +28,7 @@ import java.util.concurrent.Callable;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.QueryTimeout;
 import org.apache.lucene.search.knn.KnnCollectorManager;
@@ -120,8 +121,8 @@ abstract class AbstractKnnVectorQuery extends Query {
 Weight filterWeight,
 TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
 throws IOException {
-Bits liveDocs = ctx.reader().getLiveDocs();
-int maxDoc = ctx.reader().maxDoc();
+final LeafReader reader = ctx.reader();
+final Bits liveDocs = reader.getLiveDocs();
 if (filterWeight == null) {
 return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
@@ -132,7 +133,7 @@ abstract class AbstractKnnVectorQuery extends Query {
 return NO_RESULTS;
 }
-BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
+BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
 final int cost = acceptDocs.cardinality();
 QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
@@ -267,19 +268,19 @@ abstract class AbstractKnnVectorQuery extends Query {
 docs[i] = topK.scoreDocs[i].doc;
 scores[i] = topK.scoreDocs[i].score;
 }
-int[] segmentStarts = findSegmentStarts(reader, docs);
+int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
 return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
 }
-static int[] findSegmentStarts(IndexReader reader, int[] docs) {
-int[] starts = new int[reader.leaves().size() + 1];
+static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
+int[] starts = new int[leaves.size() + 1];
 starts[starts.length - 1] = docs.length;
 if (starts.length == 2) {
 return starts;
 }
 int resultIndex = 0;
 for (int i = 1; i < starts.length - 1; i++) {
-int upper = reader.leaves().get(i).docBase;
+int upper = leaves.get(i).docBase;
 resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
 if (resultIndex < 0) {
 resultIndex = -1 - resultIndex;

View File

@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.document.KnnFloatVectorField;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.knn.KnnCollectorManager;
 import org.apache.lucene.util.ArrayUtil;
@@ -83,24 +84,26 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
 KnnCollectorManager knnCollectorManager)
 throws IOException {
 KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
-ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
+LeafReader reader = context.reader();
+ByteVectorValues byteVectorValues = reader.getByteVectorValues(field);
 if (byteVectorValues == null) {
-ByteVectorValues.checkField(context.reader(), field);
+ByteVectorValues.checkField(reader, field);
 return NO_RESULTS;
 }
 if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) {
 return NO_RESULTS;
 }
-context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
+reader.searchNearestVectors(field, target, knnCollector, acceptDocs);
 TopDocs results = knnCollector.topDocs();
 return results != null ? results : NO_RESULTS;
 }
 @Override
 VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
-ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
+LeafReader reader = context.reader();
+ByteVectorValues vectorValues = reader.getByteVectorValues(field);
 if (vectorValues == null) {
-ByteVectorValues.checkField(context.reader(), field);
+ByteVectorValues.checkField(reader, field);
 return null;
 }
 return vectorValues.scorer(target);

View File

@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.document.KnnFloatVectorField;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.knn.KnnCollectorManager;
 import org.apache.lucene.util.ArrayUtil;
@@ -84,24 +85,26 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
 KnnCollectorManager knnCollectorManager)
 throws IOException {
 KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
-FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
+LeafReader reader = context.reader();
+FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
 if (floatVectorValues == null) {
-FloatVectorValues.checkField(context.reader(), field);
+FloatVectorValues.checkField(reader, field);
 return NO_RESULTS;
 }
 if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
 return NO_RESULTS;
 }
-context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
+reader.searchNearestVectors(field, target, knnCollector, acceptDocs);
 TopDocs results = knnCollector.topDocs();
 return results != null ? results : NO_RESULTS;
 }
 @Override
 VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
-FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
+LeafReader reader = context.reader();
+FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
 if (vectorValues == null) {
-FloatVectorValues.checkField(context.reader(), field);
+FloatVectorValues.checkField(reader, field);
 return null;
 }
 return vectorValues.scorer(target);

View File

@@ -216,7 +216,7 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
 maxScore = Math.max(maxScore, scores[i]);
 }
 IndexReader indexReader = searcher.getIndexReader();
-int[] segments = AbstractKnnVectorQuery.findSegmentStarts(indexReader, docs);
+int[] segments = AbstractKnnVectorQuery.findSegmentStarts(indexReader.leaves(), docs);
 AbstractKnnVectorQuery.DocAndScoreQuery query =
 new AbstractKnnVectorQuery.DocAndScoreQuery(