[Abstract]Knn[Byte|Float]VectorQuery tweaks: reduce duplicate method calls (#13528)

* reduce LeafReaderContext.reader()[.maxDoc()] calls in AbstractKnnVectorQuery.getLeafResults

* reduce IndexReader.leaves() calls in AbstractKnnVectorQuery.findSegmentStarts

* reduce LeafReaderContext.reader() calls in Knn(Byte|Float)VectorQuery.approximateSearch
This commit is contained in:
Christine Poerschke 2024-07-01 17:31:02 +01:00 committed by GitHub
parent 3cd406e783
commit 0ad270d8b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 18 deletions

View File

@ -28,6 +28,7 @@ import java.util.concurrent.Callable;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
@ -120,8 +121,8 @@ abstract class AbstractKnnVectorQuery extends Query {
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
@ -132,7 +133,7 @@ abstract class AbstractKnnVectorQuery extends Query {
return NO_RESULTS;
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
final int cost = acceptDocs.cardinality();
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
@ -267,19 +268,19 @@ abstract class AbstractKnnVectorQuery extends Query {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
}
static int[] findSegmentStarts(IndexReader reader, int[] docs) {
int[] starts = new int[reader.leaves().size() + 1];
static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
int[] starts = new int[leaves.size() + 1];
starts[starts.length - 1] = docs.length;
if (starts.length == 2) {
return starts;
}
int resultIndex = 0;
for (int i = 1; i < starts.length - 1; i++) {
int upper = reader.leaves().get(i).docBase;
int upper = leaves.get(i).docBase;
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
if (resultIndex < 0) {
resultIndex = -1 - resultIndex;

View File

@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
@ -83,24 +84,26 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
LeafReader reader = context.reader();
ByteVectorValues byteVectorValues = reader.getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
ByteVectorValues.checkField(reader, field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
reader.searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
LeafReader reader = context.reader();
ByteVectorValues vectorValues = reader.getByteVectorValues(field);
if (vectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
ByteVectorValues.checkField(reader, field);
return null;
}
return vectorValues.scorer(target);

View File

@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
@ -84,24 +85,26 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
LeafReader reader = context.reader();
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
FloatVectorValues.checkField(reader, field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
reader.searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
LeafReader reader = context.reader();
FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
if (vectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
FloatVectorValues.checkField(reader, field);
return null;
}
return vectorValues.scorer(target);

View File

@ -216,7 +216,7 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
maxScore = Math.max(maxScore, scores[i]);
}
IndexReader indexReader = searcher.getIndexReader();
int[] segments = AbstractKnnVectorQuery.findSegmentStarts(indexReader, docs);
int[] segments = AbstractKnnVectorQuery.findSegmentStarts(indexReader.leaves(), docs);
AbstractKnnVectorQuery.DocAndScoreQuery query =
new AbstractKnnVectorQuery.DocAndScoreQuery(