mirror of https://github.com/apache/lucene.git
Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code duplication (#13529)
* Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code duplication * action review feedback: use org.apache.lucene.util.IOSupplier
This commit is contained in:
parent
0ad270d8b0
commit
f4cd4b46fc
|
@ -45,6 +45,7 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.store.ReadAdvice;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.IOSupplier;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
|
@ -248,45 +249,39 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return;
|
||||
}
|
||||
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
|
||||
final KnnCollector collector =
|
||||
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
|
||||
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
|
||||
if (knnCollector.k() < scorer.maxOrd()) {
|
||||
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
|
||||
} else {
|
||||
// if k is larger than the number of vectors, we can just iterate over all vectors
|
||||
// and collect them
|
||||
for (int i = 0; i < scorer.maxOrd(); i++) {
|
||||
if (acceptedOrds == null || acceptedOrds.get(i)) {
|
||||
if (knnCollector.earlyTerminated()) {
|
||||
break;
|
||||
}
|
||||
knnCollector.incVisitedCount(1);
|
||||
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
search(
|
||||
fields.get(field),
|
||||
knnCollector,
|
||||
acceptDocs,
|
||||
VectorEncoding.FLOAT32,
|
||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
search(
|
||||
fields.get(field),
|
||||
knnCollector,
|
||||
acceptDocs,
|
||||
VectorEncoding.BYTE,
|
||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||
}
|
||||
|
||||
private void search(
|
||||
FieldEntry fieldEntry,
|
||||
KnnCollector knnCollector,
|
||||
Bits acceptDocs,
|
||||
VectorEncoding vectorEncoding,
|
||||
IOSupplier<RandomVectorScorer> scorerSupplier)
|
||||
throws IOException {
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
|| fieldEntry.vectorEncoding != vectorEncoding) {
|
||||
return;
|
||||
}
|
||||
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
|
||||
final RandomVectorScorer scorer = scorerSupplier.get();
|
||||
final KnnCollector collector =
|
||||
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
|
||||
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
|
||||
|
|
Loading…
Reference in New Issue