LUCENE-10142: use a better RNG for HNSW vectors

The HNSW code makes extensive use of random numbers, but it relies on the
legacy java.util.Random, which is slow. Swap in java.util.SplittableRandom
for better performance.
This commit is contained in:
Robert Muir 2021-10-02 15:23:28 -04:00
parent 3dee08a09a
commit b4fcdd9770
No known key found for this signature in database
GPG Key ID: 817AE1DD322D7ECA
4 changed files with 18 additions and 10 deletions

View File

@ -24,7 +24,7 @@ import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.CorruptIndexException;
@ -242,7 +242,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
// use a seed that is fixed for the index so we get reproducible results for the same query
final Random random = new Random(checksumSeed);
final SplittableRandom random = new SplittableRandom(checksumSeed);
NeighborQueue results =
HnswGraph.search(
target,

View File

@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -97,7 +97,7 @@ public final class HnswGraph extends KnnGraphValues {
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
Bits acceptOrds,
Random random)
SplittableRandom random)
throws IOException {
int size = graphValues.size();

View File

@ -20,7 +20,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.Random;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -45,7 +45,7 @@ public final class HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final Random random;
private final SplittableRandom random;
private final BoundsChecker bound;
final HnswGraph hnsw;
@ -86,7 +86,7 @@ public final class HnswGraphBuilder {
this.beamWidth = beamWidth;
this.hnsw = new HnswGraph(maxConn);
bound = BoundsChecker.create(similarityFunction.reversed);
random = new Random(seed);
random = new SplittableRandom(seed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
}

View File

@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
@ -141,7 +142,7 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null,
random());
new SplittableRandom(random().nextLong()));
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
@ -182,7 +183,7 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
random());
new SplittableRandom(random().nextLong()));
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
@ -325,7 +326,14 @@ public class TestHnswGraph extends LuceneTestCase {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
HnswGraph.search(
query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
query,
topK,
100,
vectors,
similarityFunction,
hnsw,
acceptOrds,
new SplittableRandom(random().nextLong()));
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {