Simple rename of unreleased quantization parameter (#12811)

Benjamin Trent 2023-11-15 15:00:12 -05:00 committed by GitHub
parent 05a336ea69
commit a26a80c89c
8 changed files with 165 additions and 117 deletions

View File

@@ -84,8 +84,8 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
    * @param beamWidth the size of the queue maintained during graph construction.
    * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
    *     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
-   * @param configuredQuantile the quantile for scalar quantizing the vectors, when `null` it is
-   *     calculated based on the vector field dimensions.
+   * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
+   *     it is calculated based on the vector field dimensions.
    * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
    *     generated by this format to do the merge
    */
@@ -93,7 +93,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
       int maxConn,
       int beamWidth,
       int numMergeWorkers,
-      Float configuredQuantile,
+      Float confidenceInterval,
       ExecutorService mergeExec) {
     super("Lucene99HnswScalarQuantizedVectorsFormat");
     if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
@@ -122,7 +122,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
     }
     this.numMergeWorkers = numMergeWorkers;
     this.mergeExec = mergeExec;
-    this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(configuredQuantile);
+    this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
   }

   @Override
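
Below is a minimal usage sketch (not part of this commit) showing how the renamed parameter flows through the five-argument constructor after the change. The maxConn/beamWidth values, variable names, and the wrapper class are illustrative assumptions; only the parameter order itself comes from the patch. With numMergeWorkers set to 1, no merge executor is passed.

import java.util.concurrent.ExecutorService;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;

public class ConfidenceIntervalUsage {
  public static void main(String[] args) {
    // Passing null lets the writer derive max(0.9, 1 - 1/(dim + 1)) per field.
    Lucene99HnswScalarQuantizedVectorsFormat derivedInterval =
        new Lucene99HnswScalarQuantizedVectorsFormat(
            16, 100, 1, (Float) null, (ExecutorService) null);

    // An explicit confidenceInterval must fall within [0.9, 1.0], otherwise the constructor throws.
    Lucene99HnswScalarQuantizedVectorsFormat fixedInterval =
        new Lucene99HnswScalarQuantizedVectorsFormat(16, 100, 1, 0.95f, (ExecutorService) null);

    System.out.println(fixedInterval);
  }
}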

View File

@@ -43,17 +43,17 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
   private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();

-  /** The minimum quantile */
-  private static final float MINIMUM_QUANTILE = 0.9f;
+  /** The minimum confidence interval */
+  private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;

-  /** The maximum quantile */
-  private static final float MAXIMUM_QUANTILE = 1f;
+  /** The maximum confidence interval */
+  private static final float MAXIMUM_CONFIDENCE_INTERVAL = 1f;

   /**
-   * Controls the quantile used to scalar quantize the vectors the default quantile is calculated as
-   * `1-1/(vector_dimensions + 1)`
+   * Controls the confidence interval used to scalar quantize the vectors the default value is
+   * calculated as `1-1/(vector_dimensions + 1)`
    */
-  final Float quantile;
+  final Float confidenceInterval;

   /** Constructs a format using default graph construction parameters */
   public Lucene99ScalarQuantizedVectorsFormat() {
@@ -63,24 +63,26 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
   /**
    * Constructs a format using the given graph construction parameters.
    *
-   * @param quantile the quantile for scalar quantizing the vectors, when `null` it is calculated
-   *     based on the vector field dimensions.
+   * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
+   *     it is calculated based on the vector field dimensions.
    */
-  public Lucene99ScalarQuantizedVectorsFormat(Float quantile) {
-    if (quantile != null && (quantile < MINIMUM_QUANTILE || quantile > MAXIMUM_QUANTILE)) {
+  public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
+    if (confidenceInterval != null
+        && (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
+            || confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) {
       throw new IllegalArgumentException(
-          "quantile must be between "
-              + MINIMUM_QUANTILE
+          "confidenceInterval must be between "
+              + MINIMUM_CONFIDENCE_INTERVAL
               + " and "
-              + MAXIMUM_QUANTILE
-              + "; quantile="
-              + quantile);
+              + MAXIMUM_CONFIDENCE_INTERVAL
+              + "; confidenceInterval="
+              + confidenceInterval);
     }
-    this.quantile = quantile;
+    this.confidenceInterval = confidenceInterval;
   }

-  static float calculateDefaultQuantile(int vectorDimension) {
-    return Math.max(MINIMUM_QUANTILE, 1f - (1f / (vectorDimension + 1)));
+  static float calculateDefaultConfidenceInterval(int vectorDimension) {
+    return Math.max(MINIMUM_CONFIDENCE_INTERVAL, 1f - (1f / (vectorDimension + 1)));
   }

   @Override
@@ -88,8 +90,8 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
     return NAME
         + "(name="
         + NAME
-        + ", quantile="
-        + quantile
+        + ", confidenceInterval="
+        + confidenceInterval
         + ", rawVectorFormat="
        + rawVectorFormat
         + ")";
@@ -98,7 +100,7 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
   @Override
   public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
     return new Lucene99ScalarQuantizedVectorsWriter(
-        state, quantile, rawVectorFormat.fieldsWriter(state));
+        state, confidenceInterval, rawVectorFormat.fieldsWriter(state));
   }

   @Override
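
When confidenceInterval is null the format computes a per-field default of max(0.9, 1 - 1/(dim + 1)), clamped by MINIMUM_CONFIDENCE_INTERVAL. The standalone sketch below replicates the patched calculateDefaultConfidenceInterval expression to show the values it produces; the class name and the dimensions printed are illustrative, not taken from the patch.

public class DefaultConfidenceInterval {
  static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;

  // Same expression as the patched method: clamp 1 - 1/(dim + 1) to at least 0.9.
  static float calculateDefaultConfidenceInterval(int vectorDimension) {
    return Math.max(MINIMUM_CONFIDENCE_INTERVAL, 1f - (1f / (vectorDimension + 1)));
  }

  public static void main(String[] args) {
    System.out.println(calculateDefaultConfidenceInterval(4));   // 0.9 (clamped, since 1 - 1/5 = 0.8)
    System.out.println(calculateDefaultConfidenceInterval(128)); // ~0.9922
    System.out.println(calculateDefaultConfidenceInterval(768)); // ~0.9987
  }
}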

View File

@@ -303,10 +303,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
       dimension = input.readVInt();
       size = input.readInt();
       if (size > 0) {
-        float configuredQuantile = Float.intBitsToFloat(input.readInt());
+        float confidenceInterval = Float.intBitsToFloat(input.readInt());
         float minQuantile = Float.intBitsToFloat(input.readInt());
         float maxQuantile = Float.intBitsToFloat(input.readInt());
-        scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, configuredQuantile);
+        scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
       } else {
         scalarQuantizer = null;
       }
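
The reader recovers the per-field confidenceInterval and the quantile bounds from raw int bits, so the rename appears to be purely a source-level change with the same bytes on disk. A minimal sketch of that float round trip, with illustrative values and class name:

public class FloatBitsRoundTrip {
  public static void main(String[] args) {
    float confidenceInterval = 0.99f;
    int onDisk = Float.floatToIntBits(confidenceInterval); // what writeMeta stores
    float readBack = Float.intBitsToFloat(onDisk);         // what the reader recovers
    System.out.println(onDisk + " -> " + readBack);        // bit-exact round trip
  }
}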

View File

@@ -19,7 +19,7 @@ package org.apache.lucene.codecs.lucene99;
 import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
 import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
-import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
+import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
@@ -91,14 +91,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   private final List<FieldWriter> fields = new ArrayList<>();
   private final IndexOutput meta, quantizedVectorData;
-  private final Float quantile;
+  private final Float confidenceInterval;
   private final FlatVectorsWriter rawVectorDelegate;
   private boolean finished;

   Lucene99ScalarQuantizedVectorsWriter(
-      SegmentWriteState state, Float quantile, FlatVectorsWriter rawVectorDelegate)
+      SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
       throws IOException {
-    this.quantile = quantile;
+    this.confidenceInterval = confidenceInterval;
     segmentWriteState = state;
     String metaFileName =
         IndexFileNames.segmentFileName(
@@ -142,12 +142,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   public FlatFieldVectorsWriter<?> addField(
       FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
     if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
-      float quantile =
-          this.quantile == null
-              ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
-              : this.quantile;
+      float confidenceInterval =
+          this.confidenceInterval == null
+              ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
+              : this.confidenceInterval;
       FieldWriter quantizedWriter =
-          new FieldWriter(quantile, fieldInfo, segmentWriteState.infoStream, indexWriter);
+          new FieldWriter(confidenceInterval, fieldInfo, segmentWriteState.infoStream, indexWriter);
       fields.add(quantizedWriter);
       indexWriter = quantizedWriter;
     }
@@ -169,16 +169,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     DocsWithFieldSet docsWithField =
         writeQuantizedVectorData(quantizedVectorData, byteVectorValues);
     long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
-    float quantile =
-        this.quantile == null
-            ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
-            : this.quantile;
+    float confidenceInterval =
+        this.confidenceInterval == null
+            ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
+            : this.confidenceInterval;
     writeMeta(
         fieldInfo,
         segmentWriteState.segmentInfo.maxDoc(),
         vectorDataOffset,
         vectorDataLength,
-        quantile,
+        confidenceInterval,
         mergedQuantizationState.getLowerQuantile(),
         mergedQuantizationState.getUpperQuantile(),
         docsWithField);
@@ -251,7 +251,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         maxDoc,
         vectorDataOffset,
         vectorDataLength,
-        quantile,
+        confidenceInterval,
         fieldData.minQuantile,
         fieldData.maxQuantile,
         fieldData.docsWithField);
@@ -262,7 +262,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
       int maxDoc,
       long vectorDataOffset,
       long vectorDataLength,
-      Float configuredQuantizationQuantile,
+      Float confidenceInterval,
       Float lowerQuantile,
       Float upperQuantile,
       DocsWithFieldSet docsWithField)
@@ -279,9 +279,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
       assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
       meta.writeInt(
           Float.floatToIntBits(
-              configuredQuantizationQuantile != null
-                  ? configuredQuantizationQuantile
-                  : calculateDefaultQuantile(field.getVectorDimension())));
+              confidenceInterval != null
+                  ? confidenceInterval
+                  : calculateDefaultConfidenceInterval(field.getVectorDimension())));
       meta.writeInt(Float.floatToIntBits(lowerQuantile));
       meta.writeInt(Float.floatToIntBits(upperQuantile));
     }
@@ -344,7 +344,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         maxDoc,
         vectorDataOffset,
         quantizedVectorLength,
-        quantile,
+        confidenceInterval,
         fieldData.minQuantile,
         fieldData.maxQuantile,
         newDocsWithField);
@@ -374,11 +374,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState)
       throws IOException {
     assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32;
-    float quantile =
-        this.quantile == null
-            ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
-            : this.quantile;
-    return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
+    float confidenceInterval =
+        this.confidenceInterval == null
+            ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
+            : this.confidenceInterval;
+    return mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval);
   }

   private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
@@ -408,16 +408,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
           quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
       long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
       CodecUtil.retrieveChecksum(quantizationDataInput);
-      float quantile =
-          this.quantile == null
-              ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
-              : this.quantile;
+      float confidenceInterval =
+          this.confidenceInterval == null
+              ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
+              : this.confidenceInterval;
       writeMeta(
           fieldInfo,
           segmentWriteState.segmentInfo.maxDoc(),
           vectorDataOffset,
           vectorDataLength,
-          quantile,
+          confidenceInterval,
           mergedQuantizationState.getLowerQuantile(),
           mergedQuantizationState.getUpperQuantile(),
           docsWithField);
@@ -446,7 +446,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   }

   static ScalarQuantizer mergeQuantiles(
-      List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, float quantile) {
+      List<ScalarQuantizer> quantizationStates,
+      List<Integer> segmentSizes,
+      float confidenceInterval) {
     assert quantizationStates.size() == segmentSizes.size();
     if (quantizationStates.isEmpty()) {
       return null;
@@ -464,7 +466,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     }
     lowerQuantile /= totalCount;
     upperQuantile /= totalCount;
-    return new ScalarQuantizer(lowerQuantile, upperQuantile, quantile);
+    return new ScalarQuantizer(lowerQuantile, upperQuantile, confidenceInterval);
   }

   /**
@@ -521,7 +523,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   }

   static ScalarQuantizer mergeAndRecalculateQuantiles(
-      MergeState mergeState, FieldInfo fieldInfo, float quantile) throws IOException {
+      MergeState mergeState, FieldInfo fieldInfo, float confidenceInterval) throws IOException {
     List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
     List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
     for (int i = 0; i < mergeState.liveDocs.length; i++) {
@@ -536,7 +538,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         segmentSizes.add(fvv.size());
       }
     }
-    ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, quantile);
+    ScalarQuantizer mergedQuantiles =
+        mergeQuantiles(quantizationStates, segmentSizes, confidenceInterval);
     // Segments no providing quantization state indicates that their quantiles were never
     // calculated.
     // To be safe, we should always recalculate given a sample set over all the float vectors in the
@@ -545,7 +548,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
       FloatVectorValues vectorValues =
           KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
-      mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, quantile);
+      mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, confidenceInterval);
     }
     return mergedQuantiles;
   }
@@ -599,7 +602,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
     private final List<float[]> floatVectors;
     private final FieldInfo fieldInfo;
-    private final float quantile;
+    private final float confidenceInterval;
     private final InfoStream infoStream;
     private final boolean normalize;
     private float minQuantile = Float.POSITIVE_INFINITY;
@@ -609,12 +612,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     @SuppressWarnings("unchecked")
     FieldWriter(
-        float quantile,
+        float confidenceInterval,
         FieldInfo fieldInfo,
         InfoStream infoStream,
         KnnFieldVectorsWriter<?> indexWriter) {
       super((KnnFieldVectorsWriter<float[]>) indexWriter);
-      this.quantile = quantile;
+      this.confidenceInterval = confidenceInterval;
       this.fieldInfo = fieldInfo;
       this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
       this.floatVectors = new ArrayList<>();
@@ -635,15 +638,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
               new FloatVectorWrapper(
                   floatVectors,
                   fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
-              quantile);
+              confidenceInterval);
       minQuantile = quantizer.getLowerQuantile();
       maxQuantile = quantizer.getUpperQuantile();
       if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
         infoStream.message(
             QUANTIZED_VECTOR_COMPONENT,
             "quantized field="
-                + " quantile="
-                + quantile
+                + " confidenceInterval="
+                + confidenceInterval
                 + " minQuantile="
                 + minQuantile
                 + " maxQuantile="
@@ -654,7 +657,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     ScalarQuantizer createQuantizer() {
       assert finished;
-      return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
+      return new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
     }

     @Override
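
During merge, per-segment quantiles are combined into a single ScalarQuantizer. The visible hunks end with the division by totalCount, which suggests a size-weighted average; the accumulation loop itself sits outside the diff context, so the weighting in the sketch below is an assumption, and the helper and class names are illustrative.

import java.util.List;

public class MergeQuantilesSketch {
  // Returns {mergedLowerQuantile, mergedUpperQuantile}; assumes each entry is {lower, upper}.
  static float[] merge(List<float[]> perSegmentQuantiles, List<Integer> segmentSizes) {
    assert perSegmentQuantiles.size() == segmentSizes.size();
    float lower = 0f;
    float upper = 0f;
    long totalCount = 0;
    for (int i = 0; i < perSegmentQuantiles.size(); i++) {
      int size = segmentSizes.get(i);
      lower += perSegmentQuantiles.get(i)[0] * size; // assumed: weight each segment by vector count
      upper += perSegmentQuantiles.get(i)[1] * size;
      totalCount += size;
    }
    return new float[] {lower / totalCount, upper / totalCount};
  }

  public static void main(String[] args) {
    float[] merged =
        merge(List.of(new float[] {-0.5f, 0.5f}, new float[] {-0.3f, 0.7f}), List.of(100, 300));
    System.out.println(merged[0] + " " + merged[1]); // -0.35 0.65
  }
}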

View File

@@ -28,15 +28,15 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 /**
  * Will scalar quantize float vectors into `int8` byte values. This is a lossy transformation.
  * Scalar quantization works by first calculating the quantiles of the float vector values. The
- * quantiles are calculated using the configured quantile/confidence interval. The [minQuantile,
- * maxQuantile] are then used to scale the values into the range [0, 127] and bucketed into the
- * nearest byte values.
+ * quantiles are calculated using the configured confidence interval. The [minQuantile, maxQuantile]
+ * are then used to scale the values into the range [0, 127] and bucketed into the nearest byte
+ * values.
  *
  * <h2>How Scalar Quantization Works</h2>
  *
- * <p>The basic mathematical equations behind this are fairly straight forward. Given a float vector
- * `v` and a quantile `q` we can calculate the quantiles of the vector values [minQuantile,
- * maxQuantile].
+ * <p>The basic mathematical equations behind this are fairly straight forward and based on min/max
+ * normalization. Given a float vector `v` and a confidenceInterval `q` we can calculate the
+ * quantiles of the vector values [minQuantile, maxQuantile].
  *
  * <pre class="prettyprint">
  * byte = (float - minQuantile) * 127/(maxQuantile - minQuantile)
@@ -69,21 +69,20 @@ public class ScalarQuantizer {
   private final float alpha;
   private final float scale;
-  private final float minQuantile, maxQuantile, configuredQuantile;
+  private final float minQuantile, maxQuantile, confidenceInterval;

   /**
    * @param minQuantile the lower quantile of the distribution
    * @param maxQuantile the upper quantile of the distribution
-   * @param configuredQuantile The configured quantile/confidence interval used to calculate the
-   *     quantiles.
+   * @param confidenceInterval The configured confidence interval used to calculate the quantiles.
    */
-  public ScalarQuantizer(float minQuantile, float maxQuantile, float configuredQuantile) {
+  public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInterval) {
     assert maxQuantile >= minQuantile;
     this.minQuantile = minQuantile;
     this.maxQuantile = maxQuantile;
     this.scale = 127f / (maxQuantile - minQuantile);
     this.alpha = (maxQuantile - minQuantile) / 127f;
-    this.configuredQuantile = configuredQuantile;
+    this.confidenceInterval = confidenceInterval;
   }

   /**
@@ -171,8 +170,8 @@ public class ScalarQuantizer {
     return maxQuantile;
   }

-  public float getConfiguredQuantile() {
-    return configuredQuantile;
+  public float getConfidenceInterval() {
+    return confidenceInterval;
   }

   public float getConstantMultiplier() {
@@ -186,8 +185,8 @@ public class ScalarQuantizer {
         + minQuantile
         + ", maxQuantile="
         + maxQuantile
-        + ", configuredQuantile="
-        + configuredQuantile
+        + ", confidenceInterval="
+        + confidenceInterval
         + '}';
   }
@@ -201,17 +200,17 @@ public class ScalarQuantizer {
    * #SCALAR_QUANTIZATION_SAMPLE_SIZE} will be read and the quantiles calculated.
    *
    * @param floatVectorValues the float vector values from which to calculate the quantiles
-   * @param quantile the quantile/confidence interval used to calculate the quantiles
+   * @param confidenceInterval the confidence interval used to calculate the quantiles
    * @return A new {@link ScalarQuantizer} instance
    * @throws IOException if there is an error reading the float vector values
    */
-  public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float quantile)
-      throws IOException {
-    assert 0.9f <= quantile && quantile <= 1f;
+  public static ScalarQuantizer fromVectors(
+      FloatVectorValues floatVectorValues, float confidenceInterval) throws IOException {
+    assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
     if (floatVectorValues.size() == 0) {
-      return new ScalarQuantizer(0f, 0f, quantile);
+      return new ScalarQuantizer(0f, 0f, confidenceInterval);
     }
-    if (quantile == 1f) {
+    if (confidenceInterval == 1f) {
       float min = Float.POSITIVE_INFINITY;
       float max = Float.NEGATIVE_INFINITY;
       while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
@@ -220,7 +219,7 @@ public class ScalarQuantizer {
           max = Math.max(max, v);
         }
       }
-      return new ScalarQuantizer(min, max, quantile);
+      return new ScalarQuantizer(min, max, confidenceInterval);
     }
     int dim = floatVectorValues.dimension();
     if (floatVectorValues.size() < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
@@ -231,8 +230,8 @@ public class ScalarQuantizer {
         System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
         copyOffset += dim;
       }
-      float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
-      return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
+      float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
+      return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
     }
     int numFloatVecs = floatVectorValues.size();
     // Reservoir sample the vector ordinals we want to read
@@ -258,22 +257,23 @@ public class ScalarQuantizer {
       System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
       copyOffset += dim;
     }
-    float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
-    return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
+    float[] upperAndLower = getUpperAndLowerQuantile(values, confidenceInterval);
+    return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], confidenceInterval);
   }

   /**
    * Takes an array of floats, sorted or not, and returns a minimum and maximum value. These values
-   * are such that they reside on the `(1 - quantile)/2` and `quantile/2` percentiles. Example:
-   * providing floats `[0..100]` and asking for `90` quantiles will return `5` and `95`.
+   * are such that they reside on the `(1 - confidenceInterval)/2` and `confidenceInterval/2`
+   * percentiles. Example: providing floats `[0..100]` and asking for `90` quantiles will return `5`
+   * and `95`.
    *
    * @param arr array of floats
-   * @param quantileFloat the configured quantile
+   * @param confidenceInterval the configured confidence interval
    * @return lower and upper quantile values
    */
-  static float[] getUpperAndLowerQuantile(float[] arr, float quantileFloat) {
-    assert 0.9f <= quantileFloat && quantileFloat <= 1f;
-    int selectorIndex = (int) (arr.length * (1f - quantileFloat) / 2f + 0.5f);
+  static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) {
+    assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
+    int selectorIndex = (int) (arr.length * (1f - confidenceInterval) / 2f + 0.5f);
     if (selectorIndex > 0) {
       Selector selector = new FloatSelector(arr);
       selector.select(0, arr.length, arr.length - selectorIndex);
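
A small worked sketch of the arithmetic in the javadoc above: scaling a float into the [0, 127] byte range using the quantile bounds, plus the selectorIndex expression that decides how many values are trimmed from each tail for a given confidence interval. The numbers and the rounding used here are illustrative only; ScalarQuantizer itself does more (clamping and per-vector corrective terms) that this sketch skips.

public class QuantizationMath {
  public static void main(String[] args) {
    float minQuantile = -0.4f;
    float maxQuantile = 0.6f;
    float scale = 127f / (maxQuantile - minQuantile);

    // byte = (float - minQuantile) * 127 / (maxQuantile - minQuantile)
    float v = 0.1f;
    int quantized = Math.round((v - minQuantile) * scale);
    System.out.println("quantized " + v + " -> " + quantized); // 64

    // Entries ignored on each tail when picking [minQuantile, maxQuantile]:
    int arrLength = 1000;
    float confidenceInterval = 0.9f;
    int selectorIndex = (int) (arrLength * (1f - confidenceInterval) / 2f + 0.5f);
    System.out.println("tail size = " + selectorIndex); // 50
  }
}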

View File

@@ -37,6 +37,7 @@ import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.SameThreadExecutorService;
 import org.apache.lucene.util.ScalarQuantizer;
 import org.apache.lucene.util.VectorUtil;
@@ -64,11 +65,12 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
     for (int i = 0; i < numVectors; i++) {
       vectors.add(randomVector(dim));
     }
-    float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
+    float confidenceInterval =
+        Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim);
     ScalarQuantizer scalarQuantizer =
         ScalarQuantizer.fromVectors(
             new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
-            quantile);
+            confidenceInterval);
     float[] expectedCorrections = new float[numVectors];
     byte[][] expectedVectors = new byte[numVectors][];
     for (int i = 0; i < numVectors; i++) {
@@ -148,7 +150,38 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
       }
     };
     String expectedString =
-        "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
+        "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
     assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
   }
+
+  public void testLimits() {
+    expectThrows(
+        IllegalArgumentException.class, () -> new Lucene99HnswScalarQuantizedVectorsFormat(-1, 20));
+    expectThrows(
+        IllegalArgumentException.class, () -> new Lucene99HnswScalarQuantizedVectorsFormat(0, 20));
+    expectThrows(
+        IllegalArgumentException.class, () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 0));
+    expectThrows(
+        IllegalArgumentException.class, () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, -1));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(512 + 1, 20));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 3201));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 1.1f, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 0.8f, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 100, null, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () ->
+            new Lucene99HnswScalarQuantizedVectorsFormat(
+                20, 100, 1, null, new SameThreadExecutorService()));
+  }
 }

View File

@@ -21,6 +21,7 @@ import org.apache.lucene.codecs.FilterCodec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.SameThreadExecutorService;

 public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
   @Override
@@ -48,5 +49,10 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
     expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1));
     expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20));
     expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201));
+    expectThrows(
+        IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 100, 100, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswVectorsFormat(20, 100, 1, new SameThreadExecutorService()));
   }
 }

View File

@@ -32,10 +32,11 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     int numVecs = 100;
     float[][] floats = randomFloats(numVecs, dims);
-    for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
-      float error = Math.max((100 - quantile) * 0.01f, 0.01f);
+    for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
+      float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
       FloatVectorValues floatVectorValues = fromFloats(floats);
-      ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
+      ScalarQuantizer scalarQuantizer =
+          ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
       byte[][] quantized = new byte[floats.length][];
       float[] offsets =
           quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
@@ -61,10 +62,11 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     float[][] floats = randomFloats(numVecs, dims);
-    for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
-      float error = Math.max((100 - quantile) * 0.01f, 0.01f);
+    for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
+      float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
       FloatVectorValues floatVectorValues = fromFloatsNormalized(floats);
-      ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
+      ScalarQuantizer scalarQuantizer =
+          ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
       byte[][] quantized = new byte[floats.length][];
       float[] offsets =
           quantizeVectorsNormalized(
@@ -94,10 +96,11 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     for (float[] fs : floats) {
       VectorUtil.l2normalize(fs);
     }
-    for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
-      float error = Math.max((100 - quantile) * 0.01f, 0.01f);
+    for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
+      float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
       FloatVectorValues floatVectorValues = fromFloats(floats);
-      ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
+      ScalarQuantizer scalarQuantizer =
+          ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
       byte[][] quantized = new byte[floats.length][];
       float[] offsets =
           quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
@@ -123,10 +126,11 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     int numVecs = 100;
     float[][] floats = randomFloats(numVecs, dims);
-    for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
-      float error = Math.max((100 - quantile) * 0.5f, 0.5f);
+    for (float confidenceInterval : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
+      float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
       FloatVectorValues floatVectorValues = fromFloats(floats);
-      ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
+      ScalarQuantizer scalarQuantizer =
+          ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval);
       byte[][] quantized = new byte[floats.length][];
       float[] offsets =
           quantizeVectors(