mirror of https://github.com/apache/lucene.git
Fix NPE when sampling for quantization in Lucene99HnswScalarQuantizedVectorsFormat (#13027)
When merging `Lucene99HnswScalarQuantizedVectorsFormat` a NPE is possible when deleted documents are present. `ScalarQuantizer#fromVectors` doesn't take deleted documents into account. This means using `FloatVectorValues#size` may actually be larger than the actual size of live documents. Consequently, when iterating for sampling iteration too far is possible and an NPE will be thrown.
This commit is contained in:
parent
3674e779cb
commit
f16007c3ec
|
@ -243,6 +243,13 @@ Other
|
||||||
|
|
||||||
* GITHUB#12934: Cleaning up old references to Lucene/Solr. (Jakub Slowinski)
|
* GITHUB#12934: Cleaning up old references to Lucene/Solr. (Jakub Slowinski)
|
||||||
|
|
||||||
|
======================== Lucene 9.9.2 =======================
|
||||||
|
|
||||||
|
Bug Fixes
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
* GITHUB#13027: Fix NPE when sampling for quantization in Lucene99HnswScalarQuantizedVectorsFormat (Ben Trent)
|
||||||
|
|
||||||
======================== Lucene 9.9.1 =======================
|
======================== Lucene 9.9.1 =======================
|
||||||
|
|
||||||
Bug Fixes
|
Bug Fixes
|
||||||
|
|
|
@ -546,9 +546,20 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
||||||
// merged
|
// merged
|
||||||
// segment view
|
// segment view
|
||||||
if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
|
if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
|
||||||
|
int numVectors = 0;
|
||||||
FloatVectorValues vectorValues =
|
FloatVectorValues vectorValues =
|
||||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||||
mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, confidenceInterval);
|
// iterate vectorValues and increment numVectors
|
||||||
|
for (int doc = vectorValues.nextDoc();
|
||||||
|
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
doc = vectorValues.nextDoc()) {
|
||||||
|
numVectors++;
|
||||||
|
}
|
||||||
|
mergedQuantiles =
|
||||||
|
ScalarQuantizer.fromVectors(
|
||||||
|
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
|
||||||
|
confidenceInterval,
|
||||||
|
numVectors);
|
||||||
}
|
}
|
||||||
return mergedQuantiles;
|
return mergedQuantiles;
|
||||||
}
|
}
|
||||||
|
@ -638,7 +649,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
||||||
new FloatVectorWrapper(
|
new FloatVectorWrapper(
|
||||||
floatVectors,
|
floatVectors,
|
||||||
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
|
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
|
||||||
confidenceInterval);
|
confidenceInterval,
|
||||||
|
floatVectors.size());
|
||||||
minQuantile = quantizer.getLowerQuantile();
|
minQuantile = quantizer.getLowerQuantile();
|
||||||
maxQuantile = quantizer.getUpperQuantile();
|
maxQuantile = quantizer.getUpperQuantile();
|
||||||
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
||||||
|
|
|
@ -192,6 +192,53 @@ public class ScalarQuantizer {
|
||||||
|
|
||||||
private static final Random random = new Random(42);
|
private static final Random random = new Random(42);
|
||||||
|
|
||||||
|
static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
|
||||||
|
int[] vectorsToTake = IntStream.range(0, sampleSize).toArray();
|
||||||
|
for (int i = sampleSize; i < numFloatVecs; i++) {
|
||||||
|
int j = random.nextInt(i + 1);
|
||||||
|
if (j < sampleSize) {
|
||||||
|
vectorsToTake[j] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Arrays.sort(vectorsToTake);
|
||||||
|
return vectorsToTake;
|
||||||
|
}
|
||||||
|
|
||||||
|
static float[] sampleVectors(FloatVectorValues floatVectorValues, int[] vectorsToTake)
|
||||||
|
throws IOException {
|
||||||
|
int dim = floatVectorValues.dimension();
|
||||||
|
float[] values = new float[vectorsToTake.length * dim];
|
||||||
|
int copyOffset = 0;
|
||||||
|
int index = 0;
|
||||||
|
for (int i : vectorsToTake) {
|
||||||
|
while (index <= i) {
|
||||||
|
// We cannot use `advance(docId)` as MergedVectorValues does not support it
|
||||||
|
floatVectorValues.nextDoc();
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
assert floatVectorValues.docID() != NO_MORE_DOCS;
|
||||||
|
float[] floatVector = floatVectorValues.vectorValue();
|
||||||
|
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
||||||
|
copyOffset += dim;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #fromVectors(FloatVectorValues, float, int)} for details on how the quantiles are
|
||||||
|
* calculated. NOTE: If there are deleted vectors in the index, do not use this method, but
|
||||||
|
* instead use {@link #fromVectors(FloatVectorValues, float, int)}. This is because the
|
||||||
|
* totalVectorCount is used to account for deleted documents when sampling.
|
||||||
|
*/
|
||||||
|
public static ScalarQuantizer fromVectors(
|
||||||
|
FloatVectorValues floatVectorValues, float confidenceInterval) throws IOException {
|
||||||
|
return fromVectors(
|
||||||
|
floatVectorValues,
|
||||||
|
confidenceInterval,
|
||||||
|
floatVectorValues.size(),
|
||||||
|
SCALAR_QUANTIZATION_SAMPLE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This will read the float vector values and calculate the quantiles. If the number of float
|
* This will read the float vector values and calculate the quantiles. If the number of float
|
||||||
* vectors is less than {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} then all the values will be read
|
* vectors is less than {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} then all the values will be read
|
||||||
|
@ -201,13 +248,26 @@ public class ScalarQuantizer {
|
||||||
*
|
*
|
||||||
* @param floatVectorValues the float vector values from which to calculate the quantiles
|
* @param floatVectorValues the float vector values from which to calculate the quantiles
|
||||||
* @param confidenceInterval the confidence interval used to calculate the quantiles
|
* @param confidenceInterval the confidence interval used to calculate the quantiles
|
||||||
|
* @param totalVectorCount the total number of live float vectors in the index. This is vital for
|
||||||
|
* accounting for deleted documents when calculating the quantiles.
|
||||||
* @return A new {@link ScalarQuantizer} instance
|
* @return A new {@link ScalarQuantizer} instance
|
||||||
* @throws IOException if there is an error reading the float vector values
|
* @throws IOException if there is an error reading the float vector values
|
||||||
*/
|
*/
|
||||||
public static ScalarQuantizer fromVectors(
|
public static ScalarQuantizer fromVectors(
|
||||||
FloatVectorValues floatVectorValues, float confidenceInterval) throws IOException {
|
FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount)
|
||||||
|
throws IOException {
|
||||||
|
return fromVectors(
|
||||||
|
floatVectorValues, confidenceInterval, totalVectorCount, SCALAR_QUANTIZATION_SAMPLE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ScalarQuantizer fromVectors(
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
float confidenceInterval,
|
||||||
|
int totalVectorCount,
|
||||||
|
int quantizationSampleSize)
|
||||||
|
throws IOException {
|
||||||
assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
|
assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
|
||||||
if (floatVectorValues.size() == 0) {
|
if (totalVectorCount == 0) {
|
||||||
return new ScalarQuantizer(0f, 0f, confidenceInterval);
|
return new ScalarQuantizer(0f, 0f, confidenceInterval);
|
||||||
}
|
}
|
||||||
if (confidenceInterval == 1f) {
|
if (confidenceInterval == 1f) {
|
||||||
|
@ -222,9 +282,9 @@ public class ScalarQuantizer {
|
||||||
return new ScalarQuantizer(min, max, confidenceInterval);
|
return new ScalarQuantizer(min, max, confidenceInterval);
|
||||||
}
|
}
|
||||||
int dim = floatVectorValues.dimension();
|
int dim = floatVectorValues.dimension();
|
||||||
if (floatVectorValues.size() < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
|
if (totalVectorCount <= quantizationSampleSize) {
|
||||||
int copyOffset = 0;
|
int copyOffset = 0;
|
||||||
float[] values = new float[floatVectorValues.size() * dim];
|
float[] values = new float[totalVectorCount * dim];
|
||||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
float[] floatVector = floatVectorValues.vectorValue();
|
float[] floatVector = floatVectorValues.vectorValue();
|
||||||
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
||||||
|
@ -233,30 +293,10 @@ public class ScalarQuantizer {
|
||||||
float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
|
float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
|
||||||
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
|
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
|
||||||
}
|
}
|
||||||
int numFloatVecs = floatVectorValues.size();
|
int numFloatVecs = totalVectorCount;
|
||||||
// Reservoir sample the vector ordinals we want to read
|
// Reservoir sample the vector ordinals we want to read
|
||||||
float[] values = new float[SCALAR_QUANTIZATION_SAMPLE_SIZE * dim];
|
int[] vectorsToTake = reservoirSampleIndices(numFloatVecs, quantizationSampleSize);
|
||||||
int[] vectorsToTake = IntStream.range(0, SCALAR_QUANTIZATION_SAMPLE_SIZE).toArray();
|
float[] values = sampleVectors(floatVectorValues, vectorsToTake);
|
||||||
for (int i = SCALAR_QUANTIZATION_SAMPLE_SIZE; i < numFloatVecs; i++) {
|
|
||||||
int j = random.nextInt(i + 1);
|
|
||||||
if (j < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
|
|
||||||
vectorsToTake[j] = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Arrays.sort(vectorsToTake);
|
|
||||||
int copyOffset = 0;
|
|
||||||
int index = 0;
|
|
||||||
for (int i : vectorsToTake) {
|
|
||||||
while (index <= i) {
|
|
||||||
// We cannot use `advance(docId)` as MergedVectorValues does not support it
|
|
||||||
floatVectorValues.nextDoc();
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
assert floatVectorValues.docID() != NO_MORE_DOCS;
|
|
||||||
float[] floatVector = floatVectorValues.vectorValue();
|
|
||||||
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
|
||||||
copyOffset += dim;
|
|
||||||
}
|
|
||||||
float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
|
float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
|
||||||
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
|
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,8 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(
|
ScalarQuantizer.fromVectors(
|
||||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||||
confidenceInterval);
|
confidenceInterval,
|
||||||
|
numVectors);
|
||||||
float[] expectedCorrections = new float[numVectors];
|
float[] expectedCorrections = new float[numVectors];
|
||||||
byte[][] expectedVectors = new byte[numVectors][];
|
byte[][] expectedVectors = new byte[numVectors][];
|
||||||
for (int i = 0; i < numVectors; i++) {
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
|
|
@ -21,6 +21,7 @@ import static org.apache.lucene.util.TestScalarQuantizer.randomFloatArray;
|
||||||
import static org.apache.lucene.util.TestScalarQuantizer.randomFloats;
|
import static org.apache.lucene.util.TestScalarQuantizer.randomFloats;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Set;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
@ -36,7 +37,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
|
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
|
||||||
byte[][] quantized = new byte[floats.length][];
|
byte[][] quantized = new byte[floats.length][];
|
||||||
float[] offsets =
|
float[] offsets =
|
||||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
|
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
|
||||||
|
@ -64,9 +65,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||||
|
|
||||||
for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
||||||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||||
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats);
|
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null);
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
|
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
|
||||||
byte[][] quantized = new byte[floats.length][];
|
byte[][] quantized = new byte[floats.length][];
|
||||||
float[] offsets =
|
float[] offsets =
|
||||||
quantizeVectorsNormalized(
|
quantizeVectorsNormalized(
|
||||||
|
@ -100,7 +101,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
|
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
|
||||||
byte[][] quantized = new byte[floats.length][];
|
byte[][] quantized = new byte[floats.length][];
|
||||||
float[] offsets =
|
float[] offsets =
|
||||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
|
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
|
||||||
|
@ -130,7 +131,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||||
float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
|
float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
|
||||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
|
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
|
||||||
byte[][] quantized = new byte[floats.length][];
|
byte[][] quantized = new byte[floats.length][];
|
||||||
float[] offsets =
|
float[] offsets =
|
||||||
quantizeVectors(
|
quantizeVectors(
|
||||||
|
@ -204,8 +205,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||||
return offsets;
|
return offsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FloatVectorValues fromFloatsNormalized(float[][] floats) {
|
private static FloatVectorValues fromFloatsNormalized(
|
||||||
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats) {
|
float[][] floats, Set<Integer> deletedVectors) {
|
||||||
|
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) {
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue() throws IOException {
|
public float[] vectorValue() throws IOException {
|
||||||
if (curDoc == -1 || curDoc >= floats.length) {
|
if (curDoc == -1 || curDoc >= floats.length) {
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.apache.lucene.util;
|
package org.apache.lucene.util;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
@ -30,7 +32,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
||||||
|
|
||||||
float[][] floats = randomFloats(numVecs, dims);
|
float[][] floats = randomFloats(numVecs, dims);
|
||||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1);
|
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs);
|
||||||
float[] dequantized = new float[dims];
|
float[] dequantized = new float[dims];
|
||||||
byte[] quantized = new byte[dims];
|
byte[] quantized = new byte[dims];
|
||||||
byte[] requantized = new byte[dims];
|
byte[] requantized = new byte[dims];
|
||||||
|
@ -71,6 +73,87 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
||||||
assertEquals(1f, upperAndLower[1], 1e-7f);
|
assertEquals(1f, upperAndLower[1], 1e-7f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testSamplingEdgeCases() throws IOException {
|
||||||
|
int numVecs = 65;
|
||||||
|
int dims = 64;
|
||||||
|
float[][] floats = randomFloats(numVecs, dims);
|
||||||
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
|
int[] vectorsToTake = new int[] {0, floats.length - 1};
|
||||||
|
float[] sampled = ScalarQuantizer.sampleVectors(floatVectorValues, vectorsToTake);
|
||||||
|
int i = 0;
|
||||||
|
for (; i < dims; i++) {
|
||||||
|
assertEquals(floats[vectorsToTake[0]][i], sampled[i], 0.0f);
|
||||||
|
}
|
||||||
|
for (; i < dims * 2; i++) {
|
||||||
|
assertEquals(floats[vectorsToTake[1]][i - dims], sampled[i], 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testVectorSampling() throws IOException {
|
||||||
|
int numVecs = random().nextInt(123) + 5;
|
||||||
|
int dims = 4;
|
||||||
|
float[][] floats = randomFloats(numVecs, dims);
|
||||||
|
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||||
|
int[] vectorsToTake =
|
||||||
|
ScalarQuantizer.reservoirSampleIndices(numVecs, random().nextInt(numVecs - 1) + 1);
|
||||||
|
int prev = vectorsToTake[0];
|
||||||
|
// ensure sorted & unique
|
||||||
|
for (int i = 1; i < vectorsToTake.length; i++) {
|
||||||
|
assertTrue(vectorsToTake[i] > prev);
|
||||||
|
prev = vectorsToTake[i];
|
||||||
|
}
|
||||||
|
float[] sampled = ScalarQuantizer.sampleVectors(floatVectorValues, vectorsToTake);
|
||||||
|
// ensure we got the right vectors
|
||||||
|
for (int i = 0; i < vectorsToTake.length; i++) {
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
assertEquals(floats[vectorsToTake[i]][j], sampled[i * dims + j], 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testScalarWithSampling() throws IOException {
|
||||||
|
int numVecs = random().nextInt(128) + 5;
|
||||||
|
int dims = 64;
|
||||||
|
float[][] floats = randomFloats(numVecs, dims);
|
||||||
|
// Should not throw
|
||||||
|
{
|
||||||
|
TestSimpleFloatVectorValues floatVectorValues =
|
||||||
|
fromFloatsWithRandomDeletions(floats, random().nextInt(numVecs - 1) + 1);
|
||||||
|
ScalarQuantizer.fromVectors(
|
||||||
|
floatVectorValues,
|
||||||
|
0.99f,
|
||||||
|
floatVectorValues.numLiveVectors,
|
||||||
|
floatVectorValues.numLiveVectors - 1);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
TestSimpleFloatVectorValues floatVectorValues =
|
||||||
|
fromFloatsWithRandomDeletions(floats, random().nextInt(numVecs - 1) + 1);
|
||||||
|
ScalarQuantizer.fromVectors(
|
||||||
|
floatVectorValues,
|
||||||
|
0.99f,
|
||||||
|
floatVectorValues.numLiveVectors,
|
||||||
|
floatVectorValues.numLiveVectors + 1);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
TestSimpleFloatVectorValues floatVectorValues =
|
||||||
|
fromFloatsWithRandomDeletions(floats, random().nextInt(numVecs - 1) + 1);
|
||||||
|
ScalarQuantizer.fromVectors(
|
||||||
|
floatVectorValues,
|
||||||
|
0.99f,
|
||||||
|
floatVectorValues.numLiveVectors,
|
||||||
|
floatVectorValues.numLiveVectors);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
TestSimpleFloatVectorValues floatVectorValues =
|
||||||
|
fromFloatsWithRandomDeletions(floats, random().nextInt(numVecs - 1) + 1);
|
||||||
|
ScalarQuantizer.fromVectors(
|
||||||
|
floatVectorValues,
|
||||||
|
0.99f,
|
||||||
|
floatVectorValues.numLiveVectors,
|
||||||
|
random().nextInt(floatVectorValues.floats.length - 1) + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void shuffleArray(float[] ar) {
|
static void shuffleArray(float[] ar) {
|
||||||
for (int i = ar.length - 1; i > 0; i--) {
|
for (int i = ar.length - 1; i > 0; i--) {
|
||||||
int index = random().nextInt(i + 1);
|
int index = random().nextInt(i + 1);
|
||||||
|
@ -97,15 +180,29 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
static FloatVectorValues fromFloats(float[][] floats) {
|
static FloatVectorValues fromFloats(float[][] floats) {
|
||||||
return new TestSimpleFloatVectorValues(floats);
|
return new TestSimpleFloatVectorValues(floats, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
static TestSimpleFloatVectorValues fromFloatsWithRandomDeletions(
|
||||||
|
float[][] floats, int numDeleted) {
|
||||||
|
Set<Integer> deletedVectors = new HashSet<>();
|
||||||
|
for (int i = 0; i < numDeleted; i++) {
|
||||||
|
deletedVectors.add(random().nextInt(floats.length));
|
||||||
|
}
|
||||||
|
return new TestSimpleFloatVectorValues(floats, deletedVectors);
|
||||||
}
|
}
|
||||||
|
|
||||||
static class TestSimpleFloatVectorValues extends FloatVectorValues {
|
static class TestSimpleFloatVectorValues extends FloatVectorValues {
|
||||||
protected final float[][] floats;
|
protected final float[][] floats;
|
||||||
|
protected final Set<Integer> deletedVectors;
|
||||||
|
protected final int numLiveVectors;
|
||||||
protected int curDoc = -1;
|
protected int curDoc = -1;
|
||||||
|
|
||||||
TestSimpleFloatVectorValues(float[][] values) {
|
TestSimpleFloatVectorValues(float[][] values, Set<Integer> deletedVectors) {
|
||||||
this.floats = values;
|
this.floats = values;
|
||||||
|
this.deletedVectors = deletedVectors;
|
||||||
|
this.numLiveVectors =
|
||||||
|
deletedVectors == null ? values.length : values.length - deletedVectors.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -136,14 +233,18 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int nextDoc() throws IOException {
|
public int nextDoc() throws IOException {
|
||||||
curDoc++;
|
while (++curDoc < floats.length) {
|
||||||
|
if (deletedVectors == null || !deletedVectors.contains(curDoc)) {
|
||||||
|
return curDoc;
|
||||||
|
}
|
||||||
|
}
|
||||||
return docID();
|
return docID();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int advance(int target) throws IOException {
|
public int advance(int target) throws IOException {
|
||||||
curDoc = target;
|
curDoc = target - 1;
|
||||||
return docID();
|
return nextDoc();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue