diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 181d4e27157..142d20feaae 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -249,6 +249,7 @@ Optimizations
 
 * GITHUB#12962: Speedup concurrent multi-segment HNWS graph search (Mayya Sharipova, Tom Veasey)
 
+* GITHUB#13090: Prevent humongous allocations in ScalarQuantizer when building quantiles. (Ben Trent)
 
 Bug Fixes
 ---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java
index 313b7d9f240..89b2f1ed3ae 100644
--- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java
@@ -68,6 +68,9 @@ import org.apache.lucene.util.Selector;
 public class ScalarQuantizer {
 
   public static final int SCALAR_QUANTIZATION_SAMPLE_SIZE = 25_000;
+  // 20*dimension provides protection from extreme confidence intervals
+  // and also prevents humongous allocations
+  static final int SCRATCH_SIZE = 20;
 
   private final float alpha;
   private final float scale;
@@ -206,41 +209,6 @@ public class ScalarQuantizer {
     return vectorsToTake;
   }
 
-  static float[] sampleVectors(FloatVectorValues floatVectorValues, int[] vectorsToTake)
-      throws IOException {
-    int dim = floatVectorValues.dimension();
-    float[] values = new float[vectorsToTake.length * dim];
-    int copyOffset = 0;
-    int index = 0;
-    for (int i : vectorsToTake) {
-      while (index <= i) {
-        // We cannot use `advance(docId)` as MergedVectorValues does not support it
-        floatVectorValues.nextDoc();
-        index++;
-      }
-      assert floatVectorValues.docID() != NO_MORE_DOCS;
-      float[] floatVector = floatVectorValues.vectorValue();
-      System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
-      copyOffset += dim;
-    }
-    return values;
-  }
-
-  /**
-   * See {@link #fromVectors(FloatVectorValues, float, int)} for details on how the quantiles are
-   * calculated. NOTE: If there are deleted vectors in the index, do not use this method, but
-   * instead use {@link #fromVectors(FloatVectorValues, float, int)}. This is because the
-   * totalVectorCount is used to account for deleted documents when sampling.
-   */
-  public static ScalarQuantizer fromVectors(
-      FloatVectorValues floatVectorValues, float confidenceInterval) throws IOException {
-    return fromVectors(
-        floatVectorValues,
-        confidenceInterval,
-        floatVectorValues.size(),
-        SCALAR_QUANTIZATION_SAMPLE_SIZE);
-  }
-
   /**
    * This will read the float vector values and calculate the quantiles. If the number of float
    * vectors is less than {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} then all the values will be read
@@ -269,6 +237,7 @@ public class ScalarQuantizer {
       int quantizationSampleSize)
       throws IOException {
     assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
+    assert quantizationSampleSize > SCRATCH_SIZE;
     if (totalVectorCount == 0) {
       return new ScalarQuantizer(0f, 0f, confidenceInterval);
     }
@@ -283,24 +252,60 @@ public class ScalarQuantizer {
       }
       return new ScalarQuantizer(min, max, confidenceInterval);
     }
-    int dim = floatVectorValues.dimension();
+    final float[] quantileGatheringScratch =
+        new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
+    int count = 0;
+    double upperSum = 0;
+    double lowerSum = 0;
     if (totalVectorCount <= quantizationSampleSize) {
-      int copyOffset = 0;
-      float[] values = new float[totalVectorCount * dim];
+      int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
+      int i = 0;
       while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
-        float[] floatVector = floatVectorValues.vectorValue();
-        System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
-        copyOffset += dim;
+        float[] vectorValue = floatVectorValues.vectorValue();
+        System.arraycopy(
+            vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
+        i++;
+        if (i == scratchSize) {
+          float[] upperAndLower =
+              getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
+          upperSum += upperAndLower[1];
+          lowerSum += upperAndLower[0];
+          i = 0;
+          count++;
+        }
       }
-      float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
-      return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
+      // Note, we purposefully don't use the rest of the scratch state if we have fewer than
+      // `SCRATCH_SIZE` vectors, mainly because if we are sampling so few vectors then we don't
+      // want to be adversely affected by the extreme confidence intervals over small sample sizes
+      return new ScalarQuantizer(
+          (float) lowerSum / count, (float) upperSum / count, confidenceInterval);
     }
-    int numFloatVecs = totalVectorCount;
     // Reservoir sample the vector ordinals we want to read
-    int[] vectorsToTake = reservoirSampleIndices(numFloatVecs, quantizationSampleSize);
-    float[] values = sampleVectors(floatVectorValues, vectorsToTake);
-    float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
-    return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
+    int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize);
+    int index = 0;
+    int idx = 0;
+    for (int i : vectorsToTake) {
+      while (index <= i) {
+        // We cannot use `advance(docId)` as MergedVectorValues does not support it
+        floatVectorValues.nextDoc();
+        index++;
+      }
+      assert floatVectorValues.docID() != NO_MORE_DOCS;
+      float[] vectorValue = floatVectorValues.vectorValue();
+      System.arraycopy(
+          vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length);
+      idx++;
+      if (idx == SCRATCH_SIZE) {
+        float[] upperAndLower =
+            getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
+        upperSum += upperAndLower[1];
+        lowerSum += upperAndLower[0];
+        count++;
+        idx = 0;
+      }
+    }
+    return new ScalarQuantizer(
+        (float) lowerSum / count, (float) upperSum / count, confidenceInterval);
   }
 
   /**
diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java
index cba70e747f3..926b985f538 100644
--- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java
+++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java
@@ -16,6 +16,8 @@
  */
 package org.apache.lucene.util.quantization;
 
+import static org.apache.lucene.util.quantization.ScalarQuantizer.SCRATCH_SIZE;
+
 import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
@@ -73,44 +75,6 @@ public class TestScalarQuantizer extends LuceneTestCase {
     assertEquals(1f, upperAndLower[1], 1e-7f);
   }
 
-  public void testSamplingEdgeCases() throws IOException {
-    int numVecs = 65;
-    int dims = 64;
-    float[][] floats = randomFloats(numVecs, dims);
-    FloatVectorValues floatVectorValues = fromFloats(floats);
-    int[] vectorsToTake = new int[] {0, floats.length - 1};
-    float[] sampled = ScalarQuantizer.sampleVectors(floatVectorValues, vectorsToTake);
-    int i = 0;
-    for (; i < dims; i++) {
-      assertEquals(floats[vectorsToTake[0]][i], sampled[i], 0.0f);
-    }
-    for (; i < dims * 2; i++) {
-      assertEquals(floats[vectorsToTake[1]][i - dims], sampled[i], 0.0f);
-    }
-  }
-
-  public void testVectorSampling() throws IOException {
-    int numVecs = random().nextInt(123) + 5;
-    int dims = 4;
-    float[][] floats = randomFloats(numVecs, dims);
-    FloatVectorValues floatVectorValues = fromFloats(floats);
-    int[] vectorsToTake =
-        ScalarQuantizer.reservoirSampleIndices(numVecs, random().nextInt(numVecs - 1) + 1);
-    int prev = vectorsToTake[0];
-    // ensure sorted & unique
-    for (int i = 1; i < vectorsToTake.length; i++) {
-      assertTrue(vectorsToTake[i] > prev);
-      prev = vectorsToTake[i];
-    }
-    float[] sampled = ScalarQuantizer.sampleVectors(floatVectorValues, vectorsToTake);
-    // ensure we got the right vectors
-    for (int i = 0; i < vectorsToTake.length; i++) {
-      for (int j = 0; j < dims; j++) {
-        assertEquals(floats[vectorsToTake[i]][j], sampled[i * dims + j], 0.0f);
-      }
-    }
-  }
-
   public void testScalarWithSampling() throws IOException {
     int numVecs = random().nextInt(128) + 5;
     int dims = 64;
@@ -123,7 +87,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
-          floatVectorValues.numLiveVectors - 1);
+          Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {
       TestSimpleFloatVectorValues floatVectorValues =
@@ -132,7 +96,7 @@
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
-          floatVectorValues.numLiveVectors + 1);
+          Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {
       TestSimpleFloatVectorValues floatVectorValues =
@@ -141,7 +105,7 @@
          floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
-          floatVectorValues.numLiveVectors);
+          Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {
      TestSimpleFloatVectorValues floatVectorValues =
@@ -150,7 +114,7 @@
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
-          random().nextInt(floatVectorValues.floats.length - 1) + 1);
+          Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1));
     }
   }
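
For readers skimming the patch, the core change in fromVectors is this: instead of copying every sampled vector into one values array of sampleSize * dimension floats (which, at 25,000 vectors of high dimension, becomes a single multi-megabyte array of the kind G1 treats as a humongous allocation), the quantizer now reuses a scratch buffer of at most SCRATCH_SIZE * dimension floats, computes the lower and upper quantiles per full chunk, and averages the per-chunk results. Below is a minimal, self-contained sketch of that technique. It is not Lucene code: the class name and the sort-based quantiles helper are illustrative stand-ins for ScalarQuantizer#getUpperAndLowerQuantile, which selects in place rather than sorting.

import java.util.Arrays;

/** Standalone sketch of chunked quantile gathering over a bounded scratch buffer. */
public class ChunkedQuantileSketch {

  // Mirrors ScalarQuantizer.SCRATCH_SIZE: at most 20 vectors are buffered at a time,
  // so the largest allocation is 20 * dimension floats instead of sampleSize * dimension.
  static final int SCRATCH_SIZE = 20;

  /** Returns {lowerQuantile, upperQuantile} averaged over full scratch-sized chunks. */
  static float[] chunkedQuantiles(float[][] vectors, float confidenceInterval) {
    int dim = vectors[0].length;
    int scratchSize = Math.min(SCRATCH_SIZE, vectors.length);
    float[] scratch = new float[dim * scratchSize]; // bounded, reused for every chunk
    double lowerSum = 0, upperSum = 0;
    int inScratch = 0, chunks = 0;
    for (float[] v : vectors) {
      System.arraycopy(v, 0, scratch, inScratch * dim, dim);
      if (++inScratch == scratchSize) {
        float[] lowerUpper = quantiles(scratch, confidenceInterval);
        lowerSum += lowerUpper[0];
        upperSum += lowerUpper[1];
        inScratch = 0;
        chunks++;
      }
    }
    // Like the patch, a partial trailing chunk is ignored to avoid noisy quantiles
    // computed from tiny samples.
    return new float[] {(float) (lowerSum / chunks), (float) (upperSum / chunks)};
  }

  /** Hypothetical sort-based quantile helper; Lucene selects in place instead of sorting. */
  static float[] quantiles(float[] values, float confidenceInterval) {
    float[] sorted = values.clone();
    Arrays.sort(sorted);
    int drop = (int) ((1f - confidenceInterval) * sorted.length / 2f);
    return new float[] {sorted[drop], sorted[sorted.length - 1 - drop]};
  }

  public static void main(String[] args) {
    float[][] vectors = new float[1000][64];
    for (int i = 0; i < vectors.length; i++) {
      for (int j = 0; j < 64; j++) {
        vectors[i][j] = (float) Math.random();
      }
    }
    float[] lowerUpper = chunkedQuantiles(vectors, 0.9f);
    System.out.println("lower=" + lowerUpper[0] + " upper=" + lowerUpper[1]);
  }
}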
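
When the segment holds more vectors than the sample size, the patch keeps the existing strategy of reservoir-sampling the vector ordinals to read, then walking a forward-only nextDoc() cursor, because MergedVectorValues does not support advance(docId); the sampled vectors are simply copied into the same bounded scratch buffer shown above. The sketch below isolates that sampling step. It is not the Lucene implementation; the class and method names are hypothetical, and it assumes the standard reservoir-sampling algorithm with the result sorted so a one-pass forward cursor can visit every sampled ordinal.

import java.util.Arrays;
import java.util.Random;

/** Standalone sketch: uniformly sample k ordinals from [0, n) for a forward-only reader. */
public class ReservoirSampleSketch {

  static int[] sampleIndices(int n, int k, Random random) {
    // Classic reservoir sampling (Algorithm R): keep the first k ordinals, then replace
    // entries with decreasing probability as later ordinals stream by.
    int[] reservoir = new int[k];
    for (int i = 0; i < k; i++) {
      reservoir[i] = i;
    }
    for (int i = k; i < n; i++) {
      int j = random.nextInt(i + 1);
      if (j < k) {
        reservoir[j] = i;
      }
    }
    // Sorted so a forward-only iterator can reach every sampled ordinal in a single pass.
    Arrays.sort(reservoir);
    return reservoir;
  }

  public static void main(String[] args) {
    int[] take = sampleIndices(100_000, 25_000, new Random(42));
    System.out.println("first=" + take[0] + " last=" + take[take.length - 1]);
  }
}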