Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code duplication (#13529)

* Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code duplication

* action review feedback: use org.apache.lucene.util.IOSupplier
This commit is contained in:
Christine Poerschke 2024-07-01 17:32:04 +01:00 committed by GitHub
parent 0ad270d8b0
commit f4cd4b46fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 24 additions and 29 deletions

View File

@ -45,6 +45,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -248,45 +249,39 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
search(
fields.get(field),
knnCollector,
acceptDocs,
VectorEncoding.FLOAT32,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
search(
fields.get(field),
knnCollector,
acceptDocs,
VectorEncoding.BYTE,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}
private void search(
FieldEntry fieldEntry,
KnnCollector knnCollector,
Bits acceptDocs,
VectorEncoding vectorEncoding,
IOSupplier<RandomVectorScorer> scorerSupplier)
throws IOException {
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|| fieldEntry.vectorEncoding != vectorEncoding) {
return;
}
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final RandomVectorScorer scorer = scorerSupplier.get();
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);