mirror of
https://github.com/apache/lucene.git
synced 2025-03-06 16:29:30 +00:00
LUCENE-10142: use a better RNG for HNSW vectors
This code makes extensive use of Random, but uses the old legacy java.util.Random, which is slow. Swap in SplittableRandom for better performance.
This commit is contained in:
parent
3dee08a09a
commit
b4fcdd9770
@ -24,7 +24,7 @@ import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
@ -242,7 +242,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
|
||||
|
||||
// use a seed that is fixed for the index so we get reproducible results for the same query
|
||||
final Random random = new Random(checksumSeed);
|
||||
final SplittableRandom random = new SplittableRandom(checksumSeed);
|
||||
NeighborQueue results =
|
||||
HnswGraph.search(
|
||||
target,
|
||||
|
@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
@ -97,7 +97,7 @@ public final class HnswGraph extends KnnGraphValues {
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
KnnGraphValues graphValues,
|
||||
Bits acceptOrds,
|
||||
Random random)
|
||||
SplittableRandom random)
|
||||
throws IOException {
|
||||
int size = graphValues.size();
|
||||
|
||||
|
@ -20,7 +20,7 @@ package org.apache.lucene.util.hnsw;
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Random;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
@ -45,7 +45,7 @@ public final class HnswGraphBuilder {
|
||||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final Random random;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
final HnswGraph hnsw;
|
||||
|
||||
@ -86,7 +86,7 @@ public final class HnswGraphBuilder {
|
||||
this.beamWidth = beamWidth;
|
||||
this.hnsw = new HnswGraph(maxConn);
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
random = new Random(seed);
|
||||
random = new SplittableRandom(seed);
|
||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,7 @@ import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
|
||||
@ -141,7 +142,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
null,
|
||||
random());
|
||||
new SplittableRandom(random().nextLong()));
|
||||
|
||||
int[] nodes = nn.nodes();
|
||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
||||
@ -182,7 +183,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
random());
|
||||
new SplittableRandom(random().nextLong()));
|
||||
int[] nodes = nn.nodes();
|
||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
||||
int sum = 0;
|
||||
@ -325,7 +326,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||
float[] query = randomVector(random(), dim);
|
||||
NeighborQueue actual =
|
||||
HnswGraph.search(
|
||||
query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
|
||||
query,
|
||||
topK,
|
||||
100,
|
||||
vectors,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
new SplittableRandom(random().nextLong()));
|
||||
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user