Expand scalar quantization by adding half-byte (int4) quantization (#13197)

This PR is the culmination of several streams of work:

 - Confidence interval optimizations, unlocking even smaller quantization bit widths.
 - The ability to quantize to fewer bits than just int8 or int7.
 - An optimized int4 (half-byte) vector API comparison for dot product.

Further scalar quantization gives users the choice between:

 - Further quantizing to gain space by packing two int4 values into a single byte (see the packing sketch just below).
 - Or leaving the values uncompressed, where the guaranteed maximum value (15 for int4) affords faster vector operations.
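To make the space/speed trade-off concrete, here is a minimal standalone sketch of the nibble packing used when compression is enabled. It mirrors the compressBytes/decompressBytes helpers added to OffHeapQuantizedByteVectorValues in this diff; the Int4PackingDemo wrapper and main method are illustrative only. Note that value i is paired with value i + dim/2, not with its neighbor:

public class Int4PackingDemo {
  // Pack two int4 values per byte: value i in the high nibble,
  // value compressed.length + i in the low nibble.
  static void compressBytes(byte[] raw, byte[] compressed) {
    for (int i = 0; i < compressed.length; ++i) {
      compressed[i] = (byte) ((raw[i] << 4) | raw[compressed.length + i]);
    }
  }

  // Expand the packed first numBytes bytes of buf back into 2 * numBytes values, in place.
  static void decompressBytes(byte[] buf, int numBytes) {
    for (int i = 0; i < numBytes; ++i) {
      buf[numBytes + i] = (byte) (buf[i] & 0x0F);
      buf[i] = (byte) ((buf[i] & 0xFF) >> 4);
    }
  }

  public static void main(String[] args) {
    byte[] raw = {3, 7, 12, 5};                // int4 values, all in [0, 15]
    byte[] packed = new byte[raw.length >> 1]; // half the bytes: two values per byte
    compressBytes(raw, packed);                // packed = {0x3C, 0x75}
    byte[] buf = {packed[0], packed[1], 0, 0}; // readers decompress into a dimension-sized buffer
    decompressBytes(buf, packed.length);
    System.out.println(java.util.Arrays.toString(buf)); // [3, 7, 12, 5]
  }
}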

I didn't add more Panama vector APIs, as I think trying to micro-optimize int4 for anything other than dot product would be a fool's errand. Additionally, I only focused on ARM: I experimented with getting better performance on other architectures but didn't get very far, so those fall back to dotProduct.
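For illustration, a hedged sketch of opting an index's vector fields into int4 quantization via the expanded format constructor added in this change. The argument values are examples only (maxConn=16 and beamWidth=100 match the format defaults referenced in the diff); note that bits=4 requires even vector dimensions, and numMergeWorkers=1 with a null merge executor is assumed valid:

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;

public class Int4FormatExample {
  static IndexWriterConfig int4Config() {
    return new IndexWriterConfig()
        .setCodec(
            new Lucene99Codec() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                // maxConn=16, beamWidth=100, numMergeWorkers=1, bits=4, compress=true,
                // confidenceInterval=null (computed dynamically), mergeExec=null
                return new Lucene99HnswScalarQuantizedVectorsFormat(
                    16, 100, 1, 4, true, null, null);
              }
            });
  }
}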
Commit 07d3be59af (parent bf193a7125)
Benjamin Trent, 2024-04-02 13:38:40 -04:00, committed by GitHub
23 changed files with 1074 additions and 211 deletions

View File

@ -218,6 +218,9 @@ New Features
This may improve paging logic especially when large segments are merged under memory pressure.
(Uwe Schindler, Chris Hegarty, Robert Muir, Adrien Grand)
* GITHUB#13197: Expand support for new scalar bit levels for HNSW vectors. This includes 4-bit vectors and an option
to compress them to gain a 50% reduction in memory usage. (Ben Trent)
Improvements
---------------------

View File

@ -511,7 +511,7 @@ public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestB
}
}
private static ScoreDoc[] assertKNNSearch(
static ScoreDoc[] assertKNNSearch(
IndexSearcher searcher,
float[] queryVector,
int k,

View File

@ -82,6 +82,16 @@ public class TestGenerateBwcIndices extends LuceneTestCase {
sortedTest.createBWCIndex();
}
public void testCreateInt8HNSWIndices() throws IOException {
TestInt8HnswBackwardsCompatibility int8HnswBackwardsCompatibility =
new TestInt8HnswBackwardsCompatibility(
Version.LATEST,
createPattern(
TestInt8HnswBackwardsCompatibility.INDEX_NAME,
TestInt8HnswBackwardsCompatibility.SUFFIX));
int8HnswBackwardsCompatibility.createBWCIndex();
}
private boolean isInitialMajorVersionRelease() {
return Version.LATEST.equals(Version.fromBits(Version.LATEST.major, 0, 0));
}

View File

@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_index;
import static org.apache.lucene.backward_index.TestBasicBackwardsCompatibility.assertKNNSearch;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Version;
public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTestBase {
static final String INDEX_NAME = "int8_hnsw";
static final String SUFFIX = "";
private static final Version FIRST_INT8_HNSW_VERSION = Version.LUCENE_9_10_1;
private static final String KNN_VECTOR_FIELD = "knn_field";
private static final int DOC_COUNT = 30;
private static final FieldType KNN_VECTOR_FIELD_TYPE =
KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE);
private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f};
public TestInt8HnswBackwardsCompatibility(Version version, String pattern) {
super(version, pattern);
}
/** Provides all sorted versions to the test-framework */
@ParametersFactory(argumentFormatting = "Lucene-Version:%1$s; Pattern: %2$s")
public static Iterable<Object[]> testVersionsFactory() throws IllegalAccessException {
return allVersion(INDEX_NAME, SUFFIX);
}
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
}
@Override
protected boolean supportsVersion(Version version) {
return version.onOrAfter(FIRST_INT8_HNSW_VERSION);
}
@Override
void verifyUsesDefaultCodec(Directory dir, String name) throws IOException {
// We don't use the default codec
}
public void testInt8HnswIndexAndSearch() throws Exception {
IndexWriterConfig indexWriterConfig =
newIndexWriterConfig(new MockAnalyzer(random()))
.setOpenMode(IndexWriterConfig.OpenMode.APPEND)
.setCodec(getCodec())
.setMergePolicy(newLogMergePolicy());
try (IndexWriter writer = new IndexWriter(directory, indexWriterConfig)) {
// add 10 docs
for (int i = 0; i < 10; i++) {
writer.addDocument(knnDocument(i + DOC_COUNT));
if (random().nextBoolean()) {
writer.flush();
}
}
if (random().nextBoolean()) {
writer.forceMerge(1);
}
writer.commit();
try (IndexReader reader = DirectoryReader.open(directory)) {
IndexSearcher searcher = new IndexSearcher(reader);
assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT + 10, "0");
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
}
}
// This will confirm the docs are really sorted
TestUtil.checkIndex(directory);
}
@Override
protected void createIndex(Directory dir) throws IOException {
IndexWriterConfig conf =
new IndexWriterConfig(new MockAnalyzer(random()))
.setMaxBufferedDocs(10)
.setCodec(TestUtil.getDefaultCodec())
.setMergePolicy(NoMergePolicy.INSTANCE);
try (IndexWriter writer = new IndexWriter(dir, conf)) {
for (int i = 0; i < DOC_COUNT; i++) {
writer.addDocument(knnDocument(i));
}
writer.forceMerge(1);
}
try (DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
}
}
private static Document knnDocument(int id) {
Document doc = new Document();
float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id};
doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE));
doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
return doc;
}
public void testReadOldIndices() throws Exception {
try (DirectoryReader reader = DirectoryReader.open(directory)) {
IndexSearcher searcher = new IndexSearcher(reader);
assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
}
}
}

View File

@ -36,6 +36,8 @@ public class VectorUtilBenchmark {
private byte[] bytesA;
private byte[] bytesB;
private byte[] halfBytesA;
private byte[] halfBytesB;
private float[] floatsA;
private float[] floatsB;
@ -51,6 +53,14 @@ public class VectorUtilBenchmark {
bytesB = new byte[size];
random.nextBytes(bytesA);
random.nextBytes(bytesB);
// random half byte arrays for binary methods
// this means that all values must be between 0 and 15
halfBytesA = new byte[size];
halfBytesB = new byte[size];
for (int i = 0; i < size; ++i) {
halfBytesA[i] = (byte) random.nextInt(16);
halfBytesB[i] = (byte) random.nextInt(16);
}
// random float arrays for float methods
floatsA = new float[size];
@ -94,6 +104,17 @@ public class VectorUtilBenchmark {
return VectorUtil.squareDistance(bytesA, bytesB);
}
@Benchmark
public int binaryHalfByteScalar() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
}
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVector() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
}
@Benchmark
public float floatCosineScalar() {
return VectorUtil.cosine(floatsA, floatsB);

View File

@ -65,7 +65,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
/** Constructs a format using default graph construction parameters */
public Lucene99HnswScalarQuantizedVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null);
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
}
/**
@ -75,7 +75,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
* @param beamWidth the size of the queue maintained during graph construction.
*/
public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null);
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
}
/**
@ -85,6 +85,11 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
* @param beamWidth the size of the queue maintained during graph construction.
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
* inclusive)
* @param compress whether to compress the vectors; if true, vectors quantized with 4 or fewer
* bits are packed two values per byte. If false, the vectors are stored as
* is. This trades memory usage against speed.
* @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
* it is calculated based on the vector field dimensions.
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
@ -94,6 +99,8 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
int maxConn,
int beamWidth,
int numMergeWorkers,
int bits,
boolean compress,
Float confidenceInterval,
ExecutorService mergeExec) {
super("Lucene99HnswScalarQuantizedVectorsFormat");
@ -127,7 +134,8 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
this.flatVectorsFormat =
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
}
@Override

View File

@ -30,12 +30,17 @@ import org.apache.lucene.index.SegmentWriteState;
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
// The bits that are allowed for scalar quantization
// We only allow unsigned byte (8), signed byte (7), and half-byte (4)
private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
static final int VERSION_ADD_BITS = 1;
static final int VERSION_CURRENT = VERSION_ADD_BITS;
static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData";
static final String META_EXTENSION = "vemq";
@ -55,18 +60,27 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
*/
final Float confidenceInterval;
final byte bits;
final boolean compress;
/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
this(null);
this(null, 7, true);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
* it is calculated based on the vector field dimensions.
* it is calculated dynamically.
* @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
* inclusive)
* @param compress whether to compress the vectors; if true, vectors quantized with 4 or fewer
* bits are packed two values per byte. If false, the vectors are stored as
* is. This trades memory usage against speed.
*/
public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
if (confidenceInterval != null
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
|| confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) {
@ -78,7 +92,12 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
+ "; confidenceInterval="
+ confidenceInterval);
}
if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) {
throw new IllegalArgumentException("bits must be one of: 4, 7, 8; bits=" + bits);
}
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
}
public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@ -92,6 +111,10 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
+ NAME
+ ", confidenceInterval="
+ confidenceInterval
+ ", bits="
+ bits
+ ", compress="
+ compress
+ ", rawVectorFormat="
+ rawVectorFormat
+ ")";
@ -100,7 +123,7 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsWriter(
state, confidenceInterval, rawVectorFormat.fieldsWriter(state));
state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state));
}
@Override

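As a usage note for the bits validation above, a minimal sketch (the BitsValidationDemo class is hypothetical) of which bits values the new three-argument constructor accepts:

import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;

public class BitsValidationDemo {
  public static void main(String[] args) {
    // Accepted: 4, 7, or 8 bits; a null confidenceInterval is computed dynamically.
    new Lucene99ScalarQuantizedVectorsFormat(null, 4, true);
    new Lucene99ScalarQuantizedVectorsFormat(null, 7, true);
    // Rejected: throws IllegalArgumentException("bits must be one of: 4, 7, 8; bits=5")
    new Lucene99ScalarQuantizedVectorsFormat(null, 5, true);
  }
}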
View File

@ -82,7 +82,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta, state.fieldInfos);
readFields(meta, versionMeta, state.fieldInfos);
} catch (Throwable exception) {
priorE = exception;
} finally {
@ -102,13 +102,14 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
}
}
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
private void readFields(ChecksumIndexInput meta, int versionMeta, FieldInfos infos)
throws IOException {
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
FieldInfo info = infos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
FieldEntry fieldEntry = readField(meta, info);
FieldEntry fieldEntry = readField(meta, versionMeta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
}
@ -126,8 +127,13 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
+ fieldEntry.dimension);
}
// int8 quantized and calculated stored offset.
long quantizedVectorBytes = dimension + Float.BYTES;
final long quantizedVectorBytes;
if (fieldEntry.bits <= 4 && fieldEntry.compress) {
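// two int4 values are packed into each byte, plus a trailing float score correction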
quantizedVectorBytes = ((dimension + 1) >> 1) + Float.BYTES;
} else {
// int8 quantized and calculated stored offset.
quantizedVectorBytes = dimension + Float.BYTES;
}
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size);
if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
@ -209,6 +215,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.bits,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
quantizedVectorData);
@ -236,7 +244,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
return size;
}
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
if (similarityFunction != info.getVectorSimilarityFunction()) {
@ -248,7 +257,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
return new FieldEntry(input, versionMeta, vectorEncoding, info.getVectorSimilarityFunction());
}
@Override
@ -261,6 +270,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.bits,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
quantizedVectorData);
@ -285,10 +296,13 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
final long vectorDataLength;
final ScalarQuantizer scalarQuantizer;
final int size;
final byte bits;
final boolean compress;
final OrdToDocDISIReaderConfiguration ordToDoc;
FieldEntry(
IndexInput input,
int versionMeta,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
@ -299,12 +313,29 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
dimension = input.readVInt();
size = input.readInt();
if (size > 0) {
float confidenceInterval = Float.intBitsToFloat(input.readInt());
float minQuantile = Float.intBitsToFloat(input.readInt());
float maxQuantile = Float.intBitsToFloat(input.readInt());
scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
if (versionMeta < Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) {
int floatBits = input.readInt(); // confidenceInterval, unused
if (floatBits == -1) {
throw new CorruptIndexException(
"Missing confidence interval for scalar quantizer", input);
}
this.bits = (byte) 7;
this.compress = false;
float minQuantile = Float.intBitsToFloat(input.readInt());
float maxQuantile = Float.intBitsToFloat(input.readInt());
scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, (byte) 7);
} else {
input.readInt(); // confidenceInterval, unused
this.bits = input.readByte();
this.compress = input.readByte() == 1;
float minQuantile = Float.intBitsToFloat(input.readInt());
float maxQuantile = Float.intBitsToFloat(input.readInt());
scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, bits);
}
} else {
scalarQuantizer = null;
this.bits = (byte) 7;
this.compress = false;
}
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
}

View File

@ -96,12 +96,20 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final IndexOutput meta, quantizedVectorData;
private final Float confidenceInterval;
private final FlatVectorsWriter rawVectorDelegate;
private final byte bits;
private final boolean compress;
private boolean finished;
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
SegmentWriteState state,
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate)
throws IOException {
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.compress = compress;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@ -145,12 +153,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
float confidenceInterval =
this.confidenceInterval == null
? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
: this.confidenceInterval;
if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) {
throw new IllegalArgumentException(
"bits="
+ bits
+ " is not supported for odd vector dimensions; vector dimension="
+ fieldInfo.getVectorDimension());
}
FieldWriter quantizedWriter =
new FieldWriter(confidenceInterval, fieldInfo, segmentWriteState.infoStream, indexWriter);
new FieldWriter(
confidenceInterval,
bits,
compress,
fieldInfo,
segmentWriteState.infoStream,
indexWriter);
fields.add(quantizedWriter);
indexWriter = quantizedWriter;
}
@ -164,24 +181,23 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
// the vectors directly to the new segment.
// No need to use temporary file as we don't have to re-open for reading
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
ScalarQuantizer mergedQuantizationState =
mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(quantizedVectorData, byteVectorValues);
writeQuantizedVectorData(quantizedVectorData, byteVectorValues, bits, compress);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
float confidenceInterval =
this.confidenceInterval == null
? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
: this.confidenceInterval;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
@ -195,7 +211,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
// Simply merge the underlying delegate, which just copies the raw vector data to a new
// segment file
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
ScalarQuantizer mergedQuantizationState =
mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
return mergeOneFieldToIndex(
segmentWriteState, fieldInfo, mergeState, mergedQuantizationState);
}
@ -255,6 +272,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
fieldData.docsWithField);
@ -266,6 +285,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
long vectorDataOffset,
long vectorDataLength,
Float confidenceInterval,
byte bits,
boolean compress,
Float lowerQuantile,
Float upperQuantile,
DocsWithFieldSet docsWithField)
@ -280,11 +301,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
meta.writeInt(count);
if (count > 0) {
assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
meta.writeInt(
Float.floatToIntBits(
confidenceInterval != null
? confidenceInterval
: calculateDefaultConfidenceInterval(field.getVectorDimension())));
meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval));
meta.writeByte(bits);
meta.writeByte(compress ? (byte) 1 : (byte) 0);
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
}
@ -296,6 +315,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
? OffHeapQuantizedByteVectorValues.compressedArray(
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (float[] v : fieldData.floatVectors) {
@ -307,7 +331,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
quantizedVectorData.writeBytes(vector, vector.length);
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
} else {
quantizedVectorData.writeBytes(vector, vector.length);
}
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
@ -348,6 +377,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
vectorDataOffset,
quantizedVectorLength,
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
newDocsWithField);
@ -356,6 +387,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
? OffHeapQuantizedByteVectorValues.compressedArray(
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (int ordinal : ordMap) {
@ -367,29 +403,35 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
quantizedVectorData.writeBytes(vector, vector.length);
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
} else {
quantizedVectorData.writeBytes(vector, vector.length);
}
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
}
}
private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32;
float confidenceInterval =
this.confidenceInterval == null
? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
: this.confidenceInterval;
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval);
}
private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
SegmentWriteState segmentWriteState,
FieldInfo fieldInfo,
MergeState mergeState,
ScalarQuantizer mergedQuantizationState)
throws IOException {
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"quantized field="
+ " confidenceInterval="
+ confidenceInterval
+ " minQuantile="
+ mergedQuantizationState.getLowerQuantile()
+ " maxQuantile="
+ mergedQuantizationState.getUpperQuantile());
}
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
IndexOutput tempQuantizedVectorData =
segmentWriteState.directory.createTempOutput(
@ -401,7 +443,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues, bits, compress);
CodecUtil.writeFooter(tempQuantizedVectorData);
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
@ -421,6 +463,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
@ -438,6 +482,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
bits,
compress,
quantizationDataInput)));
} finally {
if (success == false) {
@ -449,9 +495,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
static ScalarQuantizer mergeQuantiles(
List<ScalarQuantizer> quantizationStates,
List<Integer> segmentSizes,
float confidenceInterval) {
List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, byte bits) {
assert quantizationStates.size() == segmentSizes.size();
if (quantizationStates.isEmpty()) {
return null;
@ -466,10 +510,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i);
upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i);
totalCount += segmentSizes.get(i);
if (quantizationStates.get(i).getBits() != bits) {
return null;
}
}
lowerQuantile /= totalCount;
upperQuantile /= totalCount;
return new ScalarQuantizer(lowerQuantile, upperQuantile, confidenceInterval);
return new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
}
/**
@ -531,11 +578,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
* @param mergeState The merge state
* @param fieldInfo The field info
* @param confidenceInterval The confidence interval
* @param bits The number of bits
* @return The merged quantiles
* @throws IOException If there is a low-level I/O error
*/
public static ScalarQuantizer mergeAndRecalculateQuantiles(
MergeState mergeState, FieldInfo fieldInfo, float confidenceInterval) throws IOException {
MergeState mergeState, FieldInfo fieldInfo, Float confidenceInterval, byte bits)
throws IOException {
assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
@ -550,14 +600,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
segmentSizes.add(fvv.size());
}
}
ScalarQuantizer mergedQuantiles =
mergeQuantiles(quantizationStates, segmentSizes, confidenceInterval);
ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, bits);
// Segments not providing quantization state indicates that their quantiles were never
// calculated. To be safe, we should always recalculate given a sample set over all the
// float vectors in the merged segment view
if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
if (mergedQuantiles == null
// For smaller `bits` values, we should always recalculate the quantiles
// TODO: this is very conservative, could we reuse information for even int4 quantization?
|| bits <= 4
|| shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
int numVectors = 0;
FloatVectorValues vectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
@ -568,10 +621,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
numVectors++;
}
mergedQuantiles =
ScalarQuantizer.fromVectors(
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
confidenceInterval,
numVectors);
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
fieldInfo.getVectorSimilarityFunction(),
numVectors,
bits)
: ScalarQuantizer.fromVectors(
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
confidenceInterval,
numVectors,
bits);
}
return mergedQuantiles;
}
@ -600,8 +660,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
public static DocsWithFieldSet writeQuantizedVectorData(
IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
IndexOutput output,
QuantizedByteVectorValues quantizedByteVectorValues,
byte bits,
boolean compress)
throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
final byte[] compressedVector =
compress
? OffHeapQuantizedByteVectorValues.compressedArray(
quantizedByteVectorValues.dimension(), bits)
: null;
for (int docV = quantizedByteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = quantizedByteVectorValues.nextDoc()) {
@ -609,7 +678,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
assert binaryValue.length == quantizedByteVectorValues.dimension()
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
output.writeBytes(binaryValue, binaryValue.length);
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(binaryValue, compressedVector);
output.writeBytes(compressedVector, compressedVector.length);
} else {
output.writeBytes(binaryValue, binaryValue.length);
}
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
docsWithField.add(docV);
}
@ -625,7 +699,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
private final List<float[]> floatVectors;
private final FieldInfo fieldInfo;
private final float confidenceInterval;
private final Float confidenceInterval;
private final byte bits;
private final boolean compress;
private final InfoStream infoStream;
private final boolean normalize;
private float minQuantile = Float.POSITIVE_INFINITY;
@ -635,17 +711,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
@SuppressWarnings("unchecked")
FieldWriter(
float confidenceInterval,
Float confidenceInterval,
byte bits,
boolean compress,
FieldInfo fieldInfo,
InfoStream infoStream,
KnnFieldVectorsWriter<?> indexWriter) {
super((KnnFieldVectorsWriter<float[]>) indexWriter);
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.fieldInfo = fieldInfo;
this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
this.floatVectors = new ArrayList<>();
this.infoStream = infoStream;
this.docsWithField = new DocsWithFieldSet();
this.compress = compress;
}
void finish() throws IOException {
@ -656,13 +736,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
finished = true;
return;
}
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
ScalarQuantizer quantizer =
ScalarQuantizer.fromVectors(
new FloatVectorWrapper(
floatVectors,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
confidenceInterval,
floatVectors.size());
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
floatVectorValues,
fieldInfo.getVectorSimilarityFunction(),
floatVectors.size(),
bits)
: ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floatVectors.size(), bits);
minQuantile = quantizer.getLowerQuantile();
maxQuantile = quantizer.getUpperQuantile();
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
@ -671,6 +754,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
"quantized field="
+ " confidenceInterval="
+ confidenceInterval
+ " bits="
+ bits
+ " minQuantile="
+ minQuantile
+ " maxQuantile="
@ -681,7 +766,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
ScalarQuantizer createQuantizer() {
assert finished;
return new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
return new ScalarQuantizer(minQuantile, maxQuantile, bits);
}
@Override
@ -765,7 +850,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
}
private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
private final QuantizedByteVectorValues values;
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
@ -799,6 +884,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
// Or we have never been quantized.
if (reader == null
|| reader.getQuantizationState(fieldInfo.name) == null
// For smaller `bits` values, we should always recalculate the quantiles
// TODO: this is very conservative, could we reuse information for even int4
// quantization?
|| scalarQuantizer.getBits() <= 4
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
sub =
new QuantizedByteVectorValueSub(
@ -884,7 +973,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
}
private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
@ -999,14 +1088,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
}
private static final class OffsetCorrectedQuantizedByteVectorValues
extends QuantizedByteVectorValues {
static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues {
private final QuantizedByteVectorValues in;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;
private OffsetCorrectedQuantizedByteVectorValues(
OffsetCorrectedQuantizedByteVectorValues(
QuantizedByteVectorValues in,
VectorSimilarityFunction vectorSimilarityFunction,
ScalarQuantizer scalarQuantizer,

View File

@ -36,6 +36,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
protected final int dimension;
protected final int size;
protected final int numBytes;
protected final byte bits;
protected final boolean compress;
protected final IndexInput slice;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
@ -43,11 +47,52 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
protected int lastOrd = -1;
protected final float[] scoreCorrectionConstant = new float[1];
OffHeapQuantizedByteVectorValues(int dimension, int size, IndexInput slice) {
static void decompressBytes(byte[] compressed, int numBytes) {
if (numBytes == compressed.length) {
return;
}
if (numBytes << 1 != compressed.length) {
throw new IllegalArgumentException(
"numBytes: " + numBytes + " does not match compressed length: " + compressed.length);
}
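// each packed byte holds value i in its high nibble and value numBytes + i in its low nibble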
for (int i = 0; i < numBytes; ++i) {
compressed[numBytes + i] = (byte) (compressed[i] & 0x0F);
compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4);
}
}
static byte[] compressedArray(int dimension, byte bits) {
if (bits <= 4) {
return new byte[(dimension + 1) >> 1];
} else {
return null;
}
}
static void compressBytes(byte[] raw, byte[] compressed) {
if (compressed.length != ((raw.length + 1) >> 1)) {
throw new IllegalArgumentException(
"compressed length: " + compressed.length + " does not match raw length: " + raw.length);
}
for (int i = 0; i < compressed.length; ++i) {
int v = (raw[i] << 4) | raw[compressed.length + i];
compressed[i] = (byte) v;
}
}
OffHeapQuantizedByteVectorValues(
int dimension, int size, byte bits, boolean compress, IndexInput slice) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = dimension + Float.BYTES;
this.bits = bits;
this.compress = compress;
if (bits <= 4 && compress) {
this.numBytes = (dimension + 1) >> 1;
} else {
this.numBytes = dimension;
}
this.byteSize = this.numBytes + Float.BYTES;
byteBuffer = ByteBuffer.allocate(dimension);
binaryValue = byteBuffer.array();
}
@ -68,8 +113,9 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
return binaryValue;
}
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
slice.readFloats(scoreCorrectionConstant, 0, 1);
decompressBytes(binaryValue, numBytes);
lastOrd = targetOrd;
return binaryValue;
}
@ -83,6 +129,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
byte bits,
boolean compress,
long quantizedVectorDataOffset,
long quantizedVectorDataLength,
IndexInput vectorData)
@ -94,9 +142,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
vectorData.slice(
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, size, bytesSlice);
return new DenseOffHeapVectorValues(dimension, size, bits, compress, bytesSlice);
} else {
return new SparseOffHeapVectorValues(configuration, dimension, size, vectorData, bytesSlice);
return new SparseOffHeapVectorValues(
configuration, dimension, size, bits, compress, vectorData, bytesSlice);
}
}
@ -108,8 +157,9 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
super(dimension, size, slice);
public DenseOffHeapVectorValues(
int dimension, int size, byte bits, boolean compress, IndexInput slice) {
super(dimension, size, bits, compress, slice);
}
@Override
@ -138,7 +188,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
return new DenseOffHeapVectorValues(dimension, size, bits, compress, slice.clone());
}
@Override
@ -158,10 +208,12 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
byte bits,
boolean compress,
IndexInput dataIn,
IndexInput slice)
throws IOException {
super(dimension, size, slice);
super(dimension, size, bits, compress, slice);
this.configuration = configuration;
this.dataIn = dataIn;
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
@ -191,7 +243,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(configuration, dimension, size, dataIn, slice.clone());
return new SparseOffHeapVectorValues(
configuration, dimension, size, bits, compress, dataIn, slice.clone());
}
@Override
@ -221,7 +274,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null);
super(dimension, 0, (byte) 7, false, null);
}
private int doc = -1;

View File

@ -151,6 +151,11 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
return total;
}
@Override
public int int4DotProduct(byte[] a, byte[] b) {
return dotProduct(a, b);
}
@Override
public float cosine(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.

View File

@ -36,6 +36,9 @@ public interface VectorUtilSupport {
/** Returns the dot product computed over signed bytes. */
int dotProduct(byte[] a, byte[] b);
/** Returns the dot product computed over the bytes, assuming the values are int4 encoded. */
int int4DotProduct(byte[] a, byte[] b);
/** Returns the cosine similarity between the two byte vectors. */
float cosine(byte[] a, byte[] b);

View File

@ -175,6 +175,13 @@ public final class VectorUtil {
return IMPL.dotProduct(a, b);
}
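/** Returns the dot product of two vectors whose bytes hold int4 values (0 to 15, inclusive). */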
public static int int4DotProduct(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return IMPL.int4DotProduct(a, b);
}
/**
* Dot product score computed over signed bytes, scaled to be in [0, 1].
*

View File

@ -76,7 +76,7 @@ public class ScalarQuantizedRandomVectorScorer
this.queryOffset = correction;
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier());
similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
this.values = values;
}

View File

@ -37,7 +37,7 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSc
RandomAccessQuantizedByteVectorValues values) {
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier());
similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
this.values = values;
}

View File

@ -33,14 +33,17 @@ public interface ScalarQuantizedVectorSimilarity {
*
* @param sim similarity function
* @param constMultiplier constant multiplier used for quantization
* @param bits number of bits used for quantization
* @return a {@link ScalarQuantizedVectorSimilarity} that applies the appropriate corrections
*/
static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
VectorSimilarityFunction sim, float constMultiplier) {
VectorSimilarityFunction sim, float constMultiplier, byte bits) {
return switch (sim) {
case EUCLIDEAN -> new Euclidean(constMultiplier);
case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier);
case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier);
case COSINE, DOT_PRODUCT -> new DotProduct(
constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(
constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
};
}
@ -66,15 +69,17 @@ public interface ScalarQuantizedVectorSimilarity {
/** Calculates dot product on quantized vectors, applying the appropriate corrections */
class DotProduct implements ScalarQuantizedVectorSimilarity {
private final float constMultiplier;
private final ByteVectorComparator comparator;
public DotProduct(float constMultiplier) {
public DotProduct(float constMultiplier, ByteVectorComparator comparator) {
this.constMultiplier = constMultiplier;
this.comparator = comparator;
}
@Override
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
int dotProduct = comparator.compare(storedVector, queryVector);
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return (1 + adjustedDistance) / 2;
}
@ -83,17 +88,24 @@ public interface ScalarQuantizedVectorSimilarity {
/** Calculates max inner product on quantized vectors, applying the appropriate corrections */
class MaximumInnerProduct implements ScalarQuantizedVectorSimilarity {
private final float constMultiplier;
private final ByteVectorComparator comparator;
public MaximumInnerProduct(float constMultiplier) {
public MaximumInnerProduct(float constMultiplier, ByteVectorComparator comparator) {
this.constMultiplier = constMultiplier;
this.comparator = comparator;
}
@Override
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
int dotProduct = comparator.compare(storedVector, queryVector);
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return scaleMaxInnerProductScore(adjustedDistance);
}
}
/** Compares two byte vectors */
interface ByteVectorComparator {
int compare(byte[] v1, byte[] v2);
}
}

View File

@ -19,11 +19,15 @@ package org.apache.lucene.util.quantization;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.Selector;
@ -74,20 +78,23 @@ public class ScalarQuantizer {
private final float alpha;
private final float scale;
private final float minQuantile, maxQuantile, confidenceInterval;
private final byte bits;
private final float minQuantile, maxQuantile;
/**
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @param confidenceInterval The configured confidence interval used to calculate the quantiles.
* @param bits the number of bits to use for quantization
*/
public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInterval) {
public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
assert maxQuantile >= minQuantile;
assert bits > 0 && bits <= 8;
this.minQuantile = minQuantile;
this.maxQuantile = maxQuantile;
this.scale = 127f / (maxQuantile - minQuantile);
this.alpha = (maxQuantile - minQuantile) / 127f;
this.confidenceInterval = confidenceInterval;
this.bits = bits;
final float divisor = (float) ((1 << bits) - 1);
this.scale = divisor / (maxQuantile - minQuantile);
this.alpha = (maxQuantile - minQuantile) / divisor;
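// e.g. bits=4, minQuantile=-1, maxQuantile=1: divisor = 15, scale = 7.5, alpha = 2/15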
}
/**
@ -100,31 +107,38 @@ public class ScalarQuantizer {
*/
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
assert src.length == dest.length;
float correctiveOffset = 0f;
float correction = 0;
for (int i = 0; i < src.length; i++) {
float v = src[i];
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
// minQuantile)
float dx = Math.max(minQuantile, Math.min(maxQuantile, src[i])) - minQuantile;
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
float dxs = scale * dx;
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha;
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
dest[i] = (byte) Math.round(dxs);
correction += quantizeFloat(src[i], dest, i);
}
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0;
}
return correctiveOffset;
return correction;
}
private float quantizeFloat(float v, byte[] dest, int destIndex) {
assert dest == null || destIndex < dest.length;
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation:
// byte = (float - minQuantile) * (2^bits - 1)/(maxQuantile - minQuantile)
float dx = v - minQuantile;
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
// Scale the value to the range [0, 2^bits - 1], this is our quantized value
// scale = (2^bits - 1)/(maxQuantile - minQuantile)
float dxs = scale * dxc;
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha;
if (dest != null) {
dest[destIndex] = (byte) Math.round(dxs);
}
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
}
/**
@ -146,10 +160,7 @@ public class ScalarQuantizer {
for (int i = 0; i < quantizedVector.length; i++) {
// dequantize the old value in order to recalculate the corrective offset
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
float dxs = scale * dx;
float dxq = Math.round(dxs) * alpha;
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
correctiveOffset += quantizeFloat(v, null, 0);
}
return correctiveOffset;
}
@ -160,7 +171,7 @@ public class ScalarQuantizer {
* @param src the source vector
* @param dest the destination vector
*/
public void deQuantize(byte[] src, float[] dest) {
void deQuantize(byte[] src, float[] dest) {
assert src.length == dest.length;
for (int i = 0; i < src.length; i++) {
dest[i] = (alpha * src[i]) + minQuantile;
@ -175,14 +186,14 @@ public class ScalarQuantizer {
return maxQuantile;
}
public float getConfidenceInterval() {
return confidenceInterval;
}
public float getConstantMultiplier() {
return alpha * alpha;
}
public byte getBits() {
return bits;
}
@Override
public String toString() {
return "ScalarQuantizer{"
@ -190,14 +201,14 @@ public class ScalarQuantizer {
+ minQuantile
+ ", maxQuantile="
+ maxQuantile
+ ", confidenceInterval="
+ confidenceInterval
+ ", bits="
+ bits
+ '}';
}
private static final Random random = new Random(42);
static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
int[] vectorsToTake = IntStream.range(0, sampleSize).toArray();
for (int i = sampleSize; i < numFloatVecs; i++) {
int j = random.nextInt(i + 1);
@ -220,26 +231,35 @@ public class ScalarQuantizer {
* @param confidenceInterval the confidence interval used to calculate the quantiles
* @param totalVectorCount the total number of live float vectors in the index. This is vital for
* accounting for deleted documents when calculating the quantiles.
* @param bits the number of bits to use for quantization
* @return A new {@link ScalarQuantizer} instance
* @throws IOException if there is an error reading the float vector values
*/
public static ScalarQuantizer fromVectors(
FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount)
FloatVectorValues floatVectorValues,
float confidenceInterval,
int totalVectorCount,
byte bits)
throws IOException {
return fromVectors(
floatVectorValues, confidenceInterval, totalVectorCount, SCALAR_QUANTIZATION_SAMPLE_SIZE);
floatVectorValues,
confidenceInterval,
totalVectorCount,
bits,
SCALAR_QUANTIZATION_SAMPLE_SIZE);
}
static ScalarQuantizer fromVectors(
FloatVectorValues floatVectorValues,
float confidenceInterval,
int totalVectorCount,
byte bits,
int quantizationSampleSize)
throws IOException {
assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
assert quantizationSampleSize > SCRATCH_SIZE;
if (totalVectorCount == 0) {
return new ScalarQuantizer(0f, 0f, confidenceInterval);
return new ScalarQuantizer(0f, 0f, bits);
}
if (confidenceInterval == 1f) {
float min = Float.POSITIVE_INFINITY;
@ -250,13 +270,14 @@ public class ScalarQuantizer {
max = Math.max(max, v);
}
}
return new ScalarQuantizer(min, max, confidenceInterval);
return new ScalarQuantizer(min, max, bits);
}
final float[] quantileGatheringScratch =
new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
int count = 0;
double upperSum = 0;
double lowerSum = 0;
double[] upperSum = new double[1];
double[] lowerSum = new double[1];
float[] confidenceIntervals = new float[] {confidenceInterval};
if (totalVectorCount <= quantizationSampleSize) {
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
int i = 0;
@ -266,10 +287,7 @@ public class ScalarQuantizer {
vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
i++;
if (i == scratchSize) {
float[] upperAndLower =
getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
upperSum += upperAndLower[1];
lowerSum += upperAndLower[0];
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
i = 0;
count++;
}
@ -277,10 +295,8 @@ public class ScalarQuantizer {
// Note, we purposefully don't use the rest of the scratch state if we have fewer than
// `SCRATCH_SIZE` vectors, mainly because if we are sampling so few vectors then we don't
// want to be adversely affected by the extreme confidence intervals over small sample sizes
return new ScalarQuantizer(
(float) lowerSum / count, (float) upperSum / count, confidenceInterval);
return new ScalarQuantizer((float) lowerSum[0] / count, (float) upperSum[0] / count, bits);
}
// Reservoir sample the vector ordinals we want to read
int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize);
int index = 0;
int idx = 0;
@ -296,16 +312,213 @@ public class ScalarQuantizer {
vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length);
idx++;
if (idx == SCRATCH_SIZE) {
float[] upperAndLower =
getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
upperSum += upperAndLower[1];
lowerSum += upperAndLower[0];
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
count++;
idx = 0;
}
}
return new ScalarQuantizer(
(float) lowerSum / count, (float) upperSum / count, confidenceInterval);
return new ScalarQuantizer((float) lowerSum[0] / count, (float) upperSum[0] / count, bits);
}
public static ScalarQuantizer fromVectorsAutoInterval(
FloatVectorValues floatVectorValues,
VectorSimilarityFunction function,
int totalVectorCount,
byte bits)
throws IOException {
if (totalVectorCount == 0) {
return new ScalarQuantizer(0f, 0f, bits);
}
int sampleSize = Math.min(totalVectorCount, 1000);
final float[] quantileGatheringScratch =
new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
int count = 0;
double[] upperSum = new double[2];
double[] lowerSum = new double[2];
final List<float[]> sampledDocs = new ArrayList<>(sampleSize);
float[] confidenceIntervals =
new float[] {
1
- Math.min(32, floatVectorValues.dimension() / 10f)
/ (floatVectorValues.dimension() + 1),
1 - 1f / (floatVectorValues.dimension() + 1)
};
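// for dimension 128 these evaluate to roughly 0.901 and 0.992; the grid search
// below interpolates 16 candidate quantile pairs between the quantiles gathered
// at these two confidence levels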
if (totalVectorCount <= sampleSize) {
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
int i = 0;
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i);
i++;
if (i == scratchSize) {
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
i = 0;
count++;
}
}
} else {
// Reservoir sample the vector ordinals we want to read
int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, 1000);
// TODO make this faster by .advance()ing & dual iterator
int index = 0;
int idx = 0;
for (int i : vectorsToTake) {
while (index <= i) {
// We cannot use `advance(docId)` as MergedVectorValues does not support it
floatVectorValues.nextDoc();
index++;
}
assert floatVectorValues.docID() != NO_MORE_DOCS;
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx);
idx++;
if (idx == SCRATCH_SIZE) {
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
count++;
idx = 0;
}
}
}
// Here we gather the upper and lower bounds for the quantile grid search
float al = (float) lowerSum[1] / count;
float bu = (float) upperSum[1] / count;
final float au = (float) lowerSum[0] / count;
final float bl = (float) upperSum[0] / count;
final float[] lowerCandidates = new float[16];
final float[] upperCandidates = new float[16];
int idx = 0;
for (float i = 0f; i < 32f; i += 2f) {
lowerCandidates[idx] = al + i * (au - al) / 32f;
upperCandidates[idx] = bl + i * (bu - bl) / 32f;
idx++;
}
// Now we need to find the best candidate pair by correlating the true quantized
// nearest neighbor scores with the float vector scores
List<ScoreDocsAndScoreVariance> nearestNeighbors = findNearestNeighbors(sampledDocs, function);
float[] bestPair =
candidateGridSearch(
nearestNeighbors, sampledDocs, lowerCandidates, upperCandidates, function, bits);
return new ScalarQuantizer(bestPair[0], bestPair[1], bits);
}
private static void extractQuantiles(
float[] confidenceIntervals,
float[] quantileGatheringScratch,
double[] upperSum,
double[] lowerSum) {
assert confidenceIntervals.length == upperSum.length
&& confidenceIntervals.length == lowerSum.length;
for (int i = 0; i < confidenceIntervals.length; i++) {
float[] upperAndLower =
getUpperAndLowerQuantile(quantileGatheringScratch, confidenceIntervals[i]);
upperSum[i] += upperAndLower[1];
lowerSum[i] += upperAndLower[0];
}
}
private static void gatherSample(
FloatVectorValues floatVectorValues,
float[] quantileGatheringScratch,
List<float[]> sampledDocs,
int i)
throws IOException {
float[] vectorValue = floatVectorValues.vectorValue();
float[] copy = new float[vectorValue.length];
System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length);
sampledDocs.add(copy);
System.arraycopy(
vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
}
private static float[] candidateGridSearch(
List<ScoreDocsAndScoreVariance> nearestNeighbors,
List<float[]> vectors,
float[] lowerCandidates,
float[] upperCandidates,
VectorSimilarityFunction function,
byte bits) {
double maxCorr = Double.NEGATIVE_INFINITY;
float bestLower = 0f;
float bestUpper = 0f;
ScoreErrorCorrelator scoreErrorCorrelator =
new ScoreErrorCorrelator(function, nearestNeighbors, vectors, bits);
// first do a coarse grained search to find the initial best candidate pair
int bestQuadrantLower = 0;
int bestQuadrantUpper = 0;
for (int i = 0; i < lowerCandidates.length; i += 4) {
float lower = lowerCandidates[i];
for (int j = 0; j < upperCandidates.length; j += 4) {
float upper = upperCandidates[j];
if (upper <= lower) {
continue;
}
double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
if (mean > maxCorr) {
maxCorr = mean;
bestLower = lower;
bestUpper = upper;
bestQuadrantLower = i;
bestQuadrantUpper = j;
}
}
}
// Now search within the best quadrant
for (int i = bestQuadrantLower + 1; i < bestQuadrantLower + 4; i++) {
for (int j = bestQuadrantUpper + 1; j < bestQuadrantUpper + 4; j++) {
float lower = lowerCandidates[i];
float upper = upperCandidates[j];
if (upper <= lower) {
continue;
}
double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
if (mean > maxCorr) {
maxCorr = mean;
bestLower = lower;
bestUpper = upper;
}
}
}
return new float[] {bestLower, bestUpper};
}
/**
* @param vectors The vectors among which to find each vector's nearest neighbors
* @param similarityFunction The similarity function to use
* @return The top 10 nearest neighbors for each vector from the vectors list
*/
private static List<ScoreDocsAndScoreVariance> findNearestNeighbors(
List<float[]> vectors, VectorSimilarityFunction similarityFunction) {
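// brute-force half-matrix comparison over the sampled vectors (at most ~1000 of
// them), scoring each pair once and pushing the score into both vectors' queues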
List<HitQueue> queues = new ArrayList<>(vectors.size());
queues.add(new HitQueue(10, false));
for (int i = 0; i < vectors.size(); i++) {
float[] vector = vectors.get(i);
for (int j = i + 1; j < vectors.size(); j++) {
float[] otherVector = vectors.get(j);
float score = similarityFunction.compare(vector, otherVector);
// initialize the rest of the queues
if (queues.size() <= j) {
queues.add(new HitQueue(10, false));
}
queues.get(i).insertWithOverflow(new ScoreDoc(j, score));
queues.get(j).insertWithOverflow(new ScoreDoc(i, score));
}
}
// Extract the top 10 from each queue
List<ScoreDocsAndScoreVariance> result = new ArrayList<>(vectors.size());
OnlineMeanAndVar meanAndVar = new OnlineMeanAndVar();
for (int i = 0; i < vectors.size(); i++) {
HitQueue queue = queues.get(i);
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
for (int j = queue.size() - 1; j >= 0; j--) {
scoreDocs[j] = queue.pop();
assert scoreDocs[j] != null;
meanAndVar.add(scoreDocs[j].score);
}
result.add(new ScoreDocsAndScoreVariance(scoreDocs, meanAndVar.var()));
meanAndVar.reset();
}
return result;
}
/**
@ -319,9 +532,8 @@ public class ScalarQuantizer {
* @return lower and upper quantile values
*/
static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) {
assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
int selectorIndex = (int) (arr.length * (1f - confidenceInterval) / 2f + 0.5f);
if (selectorIndex > 0) {
if (selectorIndex > 0 && arr.length > 2) {
Selector selector = new FloatSelector(arr);
selector.select(0, arr.length, arr.length - selectorIndex);
selector.select(0, arr.length - selectorIndex, selectorIndex);
@ -361,4 +573,95 @@ public class ScalarQuantizer {
arr[j] = tmp;
}
}
private static class ScoreDocsAndScoreVariance {
private final ScoreDoc[] scoreDocs;
private final float scoreVariance;
public ScoreDocsAndScoreVariance(ScoreDoc[] scoreDocs, float scoreVariance) {
this.scoreDocs = scoreDocs;
this.scoreVariance = scoreVariance;
}
public ScoreDoc[] getScoreDocs() {
return scoreDocs;
}
}
private static class OnlineMeanAndVar {
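// Welford's online algorithm: numerically stable single-pass mean and variance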
private double mean = 0.0;
private double var = 0.0;
private int n = 0;
void reset() {
mean = 0.0;
var = 0.0;
n = 0;
}
void add(double x) {
n++;
double delta = x - mean;
mean += delta / n;
var += delta * (x - mean);
}
float var() {
return (float) (var / (n - 1));
}
}
/**
* Correlates the true nearest-neighbor scores with the error that quantization introduces into
* those scores. Used to find the best quantile pair for the scalar quantizer.
*/
private static class ScoreErrorCorrelator {
private final OnlineMeanAndVar corr = new OnlineMeanAndVar();
private final OnlineMeanAndVar errors = new OnlineMeanAndVar();
private final VectorSimilarityFunction function;
private final List<ScoreDocsAndScoreVariance> nearestNeighbors;
private final List<float[]> vectors;
private final byte[] query;
private final byte[] vector;
private final byte bits;
public ScoreErrorCorrelator(
VectorSimilarityFunction function,
List<ScoreDocsAndScoreVariance> nearestNeighbors,
List<float[]> vectors,
byte bits) {
this.function = function;
this.nearestNeighbors = nearestNeighbors;
this.vectors = vectors;
this.query = new byte[vectors.get(0).length];
this.vector = new byte[vectors.get(0).length];
this.bits = bits;
}
double scoreErrorCorrelation(float lowerQuantile, float upperQuantile) {
corr.reset();
ScalarQuantizer quantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
function, quantizer.getConstantMultiplier(), quantizer.bits);
for (int i = 0; i < nearestNeighbors.size(); i++) {
float queryCorrection = quantizer.quantize(vectors.get(i), query, function);
ScoreDocsAndScoreVariance scoreDocsAndScoreVariance = nearestNeighbors.get(i);
ScoreDoc[] scoreDocs = scoreDocsAndScoreVariance.getScoreDocs();
float scoreVariance = scoreDocsAndScoreVariance.scoreVariance;
// recalculate the scores of this vector against its nearest neighbors, now using
// the quantized representations
errors.reset();
for (ScoreDoc scoreDoc : scoreDocs) {
float vectorCorrection = quantizer.quantize(vectors.get(scoreDoc.doc), vector, function);
float qScore =
scalarQuantizedVectorSimilarity.score(
query, queryCorrection, vector, vectorCorrection);
errors.add(qScore - scoreDoc.score);
}
corr.add(1 - errors.var() / scoreVariance);
}
return corr.mean;
}
}
}
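The correlation accumulated by ScoreErrorCorrelator reads as an explained-variance measure: per query, 1 - Var(quantizedScore - trueScore) / Var(trueScore), averaged over the sampled queries. A standalone sketch of that quantity for one query (editor's illustration; `explainedVariance` is a hypothetical helper, not part of the patch):

static double explainedVariance(float[] trueScores, float[] quantizedScores) {
  int n = trueScores.length;
  double errMean = 0, errM2 = 0, scoreMean = 0, scoreM2 = 0;
  for (int i = 0; i < n; i++) {
    double e = quantizedScores[i] - trueScores[i];
    double dErr = e - errMean;                 // Welford update for the error stream
    errMean += dErr / (i + 1);
    errM2 += dErr * (e - errMean);
    double dScore = trueScores[i] - scoreMean; // and for the true-score stream
    scoreMean += dScore / (i + 1);
    scoreM2 += dScore * (trueScores[i] - scoreMean);
  }
  double errVar = errM2 / (n - 1);
  double scoreVar = scoreM2 / (n - 1);
  return 1 - errVar / scoreVar; // near 1 when quantization preserves the scores
}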


@ -389,6 +389,53 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return acc.reduceLanes(ADD);
}
@Override
public int int4DotProduct(byte[] a, byte[] b) {
int i = 0;
int res = 0;
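// 256-bit and wider platforms are already well served by the generic dotProduct;
// the specialized body below targets 128-bit species (e.g. ARM NEON)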
if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) {
return dotProduct(a, b);
} else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) {
i += ByteVector.SPECIES_128.loopBound(a.length);
res += int4DotProductBody128(a, b, i);
}
// scalar tail
for (; i < a.length; i++) {
res += b[i] * a[i];
}
return res;
}
private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
int sum = 0;
// iterate in chunks of 1024 items to ensure we don't overflow the short accumulator:
// int4 products are at most 15 * 15 = 225, and each short lane sums 1024 / 16 = 64
// of them per chunk, i.e. at most 64 * 225 = 14400 < Short.MAX_VALUE
for (int i = 0; i < limit; i += 1024) {
ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
int innerLimit = Math.min(limit - i, 1024);
for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
ByteVector prod8 = va8.mul(vb8);
ShortVector prod16 =
prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc0 = acc0.add(prod16.and((short) 0xFF));
va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
prod8 = va8.mul(vb8);
prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
acc1 = acc1.add(prod16.and((short) 0xFF));
}
IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
}
return sum;
}
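For reference, the SIMD body above accumulates the same products as this scalar sketch (editor's illustration; both paths assume each byte holds an unpacked 4-bit value in [0, 15]):

static int int4DotProductScalar(byte[] a, byte[] b) {
  int res = 0;
  for (int i = 0; i < a.length; i++) {
    res += a[i] * b[i]; // at most 15 * 15 = 225 per product, far from int overflow
  }
  return res;
}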
@Override
public float cosine(byte[] a, byte[] b) {
int i = 0;


@ -42,17 +42,37 @@ import org.apache.lucene.util.SameThreadExecutorService;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.junit.Before;
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
KnnVectorsFormat format;
Float confidenceInterval;
int bits;
@Before
@Override
public void setUp() throws Exception {
bits = random().nextBoolean() ? 4 : 7;
confidenceInterval = random().nextBoolean() ? 0.99f : null;
format =
new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
1,
bits,
random().nextBoolean(),
confidenceInterval,
null);
super.setUp();
}
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
return format;
}
};
}
@ -63,17 +83,25 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
VectorSimilarityFunction similarityFunction = randomSimilarity();
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
int dim = random().nextInt(64) + 1;
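// keep dims even: the compressed int4 option packs two 4-bit values per byte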
if (dim % 2 == 1) {
dim++;
}
List<float[]> vectors = new ArrayList<>(numVectors);
for (int i = 0; i < numVectors; i++) {
vectors.add(randomVector(dim));
}
float confidenceInterval =
Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
confidenceInterval,
numVectors);
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
similarityFunction,
numVectors,
(byte) bits)
: ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
confidenceInterval,
numVectors,
(byte) bits);
float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) {
@ -149,11 +177,12 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null);
return new Lucene99HnswScalarQuantizedVectorsFormat(
10, 20, 1, (byte) 4, false, 0.9f, null);
}
};
String expectedString =
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, rawVectorFormat=Lucene99FlatVectorsFormat()))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}
@ -174,15 +203,28 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 3201));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 1.1f, null));
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 1.1f, null));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 0.8f, null));
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, -1, false, null, null));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 5, false, null, null));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 9, false, null, null));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 0.8f, null));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 100, 7, false, null, null));
expectThrows(
IllegalArgumentException.class,
() ->
new Lucene99HnswScalarQuantizedVectorsFormat(
20, 100, 1, null, new SameThreadExecutorService()));
20, 100, 1, 7, false, null, new SameThreadExecutorService()));
}
// Ensures that all expected vector similarity functions are translatable


@ -39,14 +39,17 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.EUCLIDEAN, scalarQuantizer.getConstantMultiplier());
VectorSimilarityFunction.EUCLIDEAN,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
assertQuantizedScores(
floats,
quantized,
@ -69,7 +72,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectorsNormalized(
@ -78,7 +82,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
VectorUtil.l2normalize(query);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.COSINE, scalarQuantizer.getConstantMultiplier());
VectorSimilarityFunction.COSINE,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
assertQuantizedScores(
floats,
quantized,
@ -103,7 +109,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
@ -111,7 +118,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
VectorUtil.l2normalize(query);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.DOT_PRODUCT, scalarQuantizer.getConstantMultiplier());
VectorSimilarityFunction.DOT_PRODUCT,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
assertQuantizedScores(
floats,
quantized,
@ -133,7 +142,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(
@ -142,7 +152,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
scalarQuantizer.getConstantMultiplier());
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
assertQuantizedScores(
floats,
quantized,


@ -34,7 +34,8 @@ public class TestScalarQuantizer extends LuceneTestCase {
float[][] floats = randomFloats(numVecs, dims);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs, (byte) 7);
float[] dequantized = new float[dims];
byte[] quantized = new byte[dims];
byte[] requantized = new byte[dims];
@ -87,6 +88,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
floatVectorValues,
0.99f,
floatVectorValues.numLiveVectors,
(byte) 7,
Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
}
{
@ -96,6 +98,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
floatVectorValues,
0.99f,
floatVectorValues.numLiveVectors,
(byte) 7,
Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
}
{
@ -105,6 +108,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
floatVectorValues,
0.99f,
floatVectorValues.numLiveVectors,
(byte) 7,
Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
}
{
@ -114,10 +118,36 @@ public class TestScalarQuantizer extends LuceneTestCase {
floatVectorValues,
0.99f,
floatVectorValues.numLiveVectors,
(byte) 7,
Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1));
}
}
public void testFromVectorsAutoInterval() throws IOException {
int dims = 128;
int numVecs = 100;
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
float[][] floats = randomFloats(numVecs, dims);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectorsAutoInterval(
floatVectorValues, similarityFunction, numVecs, (byte) 4);
assertNotNull(scalarQuantizer);
float[] dequantized = new float[dims];
byte[] quantized = new byte[dims];
byte[] requantized = new byte[dims];
for (int i = 0; i < numVecs; i++) {
scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
scalarQuantizer.deQuantize(quantized, dequantized);
scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
for (int j = 0; j < dims; j++) {
assertEquals(dequantized[j], floats[i][j], 0.2);
assertEquals(quantized[j], requantized[j]);
}
}
}
static void shuffleArray(float[] ar) {
for (int i = ar.length - 1; i > 0; i--) {
int index = random().nextInt(i + 1);


@ -127,12 +127,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
w.addDocument(doc);
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"Inconsistency of field data structures across documents for field [f] of doc [1]."
+ " vector dimension: expected '4', but it has '3'.";
+ " vector dimension: expected '4', but it has '6'.";
assertEquals(errMsg, expected.getMessage());
}
@ -145,12 +145,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
w.commit();
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
+ "to inconsistent vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
assertEquals(errMsg, expected.getMessage());
}
}
@ -202,12 +202,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
+ "to inconsistent vector dimension=2, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -284,7 +284,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
public void testAddIndexesDirectory01() throws Exception {
String fieldName = "field";
float[] vector = new float[1];
float[] vector = new float[2];
Document doc = new Document();
doc.add(new KnnFloatVectorField(fieldName, vector, VectorSimilarityFunction.DOT_PRODUCT));
try (Directory dir = newDirectory();
@ -294,6 +294,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
vector[0] = 1;
vector[1] = 1;
w2.addDocument(doc);
w2.addIndexes(dir);
w2.forceMerge(1);
@ -322,13 +323,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
IllegalArgumentException expected =
expectThrows(
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
"cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
@ -367,7 +368,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
@ -375,7 +376,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
"cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
@ -419,13 +420,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
"cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
@ -486,7 +487,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc2);
Document doc3 = new Document();
@ -531,7 +532,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals("cannot index an empty vector", e.getMessage());
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc2);
}
}
@ -592,7 +593,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
doc.add(new StringField("id", "0", Field.Store.NO));
doc.add(
new KnnFloatVectorField(
"v", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
"v", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.addDocument(new Document());
w.commit();
@ -613,7 +614,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
// assert that knn search doesn't fail on a field with all deleted docs
TopDocs results =
leafReader.searchNearestVectors(
"v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
"v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(0, results.scoreDocs.length);
}
}
@ -626,14 +627,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
doc.add(new StringField("id", "0", Field.Store.NO));
doc.add(
new KnnFloatVectorField(
"v0", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
"v0", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.commit();
doc = new Document();
doc.add(
new KnnFloatVectorField(
"v1", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
"v1", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.forceMerge(1);
}
@ -649,6 +650,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
for (int i = 0; i < numFields; i++) {
fieldDims[i] = random().nextInt(20) + 1;
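// keep field dims even so runs against int4-compressed vector formats are valid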
if (fieldDims[i] % 2 != 0) {
fieldDims[i]++;
}
fieldSimilarityFunctions[i] = randomSimilarity();
fieldVectorEncodings[i] = randomVectorEncoding();
}
@ -731,7 +735,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
// calls to IndexWriter.addDocument.
String fieldName = "field";
float[] v = {0};
float[] v = {0, 0};
try (Directory dir = newDirectory();
IndexWriter iw =
new IndexWriter(
@ -829,25 +833,25 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw =
new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
Document doc = new Document();
float[] v = new float[] {1};
float[] v = new float[] {1, 2};
doc.add(new KnnFloatVectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN));
doc.add(
new KnnFloatVectorField(
"field2", new float[] {1, 2, 3}, VectorSimilarityFunction.EUCLIDEAN));
"field2", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN));
iw.addDocument(doc);
v[0] = 2;
iw.addDocument(doc);
doc = new Document();
doc.add(
new KnnFloatVectorField(
"field3", new float[] {1, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
"field3", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.DOT_PRODUCT));
iw.addDocument(doc);
iw.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(iw)) {
LeafReader leaf = reader.leaves().get(0).reader();
FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1");
assertEquals(1, vectorValues.dimension());
assertEquals(2, vectorValues.dimension());
assertEquals(2, vectorValues.size());
vectorValues.nextDoc();
assertEquals(1f, vectorValues.vectorValue()[0], 0);
@ -856,7 +860,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2");
assertEquals(3, vectorValues2.dimension());
assertEquals(4, vectorValues2.dimension());
assertEquals(2, vectorValues2.size());
vectorValues2.nextDoc();
assertEquals(2f, vectorValues2.vectorValue()[1], 0);
@ -865,7 +869,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());
FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3");
assertEquals(3, vectorValues3.dimension());
assertEquals(4, vectorValues3.dimension());
assertEquals(1, vectorValues3.size());
vectorValues3.nextDoc();
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
@ -889,6 +893,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
float[] scratch = new float[dimension];
int numValues = 0;
float[][] values = new float[numDoc][];
@ -965,6 +972,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
byte[] scratch = new byte[dimension];
int numValues = 0;
BytesRef[] values = new BytesRef[numDoc];
@ -1101,6 +1111,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
float[][] id2value = new float[numDoc][];
for (int i = 0; i < numDoc; i++) {
int id = random().nextInt(numDoc);
@ -1252,7 +1265,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}
private byte[] randomVector8(int dim) {
protected byte[] randomVector8(int dim) {
assert dim > 0;
float[] v = randomNormalizedVector(dim);
byte[] b = new byte[dim];
@ -1268,12 +1281,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
Document doc = new Document();
doc.add(
new KnnFloatVectorField(
"v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
"v1", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
doc.add(
new KnnFloatVectorField(
"v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
"v2", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
}
@ -1360,7 +1373,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
public void testVectorValuesReportCorrectDocs() throws Exception {
final int numDocs = atLeast(1000);
final int dim = random().nextInt(20) + 1;
int dim = random().nextInt(20) + 1;
if (dim % 2 != 0) {
dim++;
}
double fieldValuesCheckSum = 0;
int fieldDocCount = 0;
long fieldSumDocIDs = 0;