mirror of https://github.com/apache/lucene.git
Expand scalar quantization with adding half-byte (int4) quantization (#13197)
Expand scalar quantization with adding half-byte (int4) quantization (#13197)

This PR is the culmination of several streams of work:

- Confidence-interval optimizations, unlocking even smaller quantization bytes.
- The ability to quantize down smaller than just int8 or int7.
- An optimized int4 (half-byte) vector API comparison for dot-product.

The idea of further scalar quantization gives users the choice between:

- further quantizing to save space by compressing the bits into single-byte values, or
- letting quantization give guarantees around maximal values that afford faster vector operations.

I didn't add more Panama vector APIs, as trying to micro-optimize int4 for anything other than dot-product seemed a fool's errand. Additionally, I only focused on ARM. I experimented with trying to get better performance on other architectures, but didn't get very far, so on those we fall back to dotProduct.
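For readers skimming the diff, here is a minimal illustrative sketch of the two ideas the message refers to — packing two int4 values into one byte to save space, and the plain scalar int4 dot product that the optimized ARM path accelerates. This is not Lucene's implementation (see OffHeapQuantizedByteVectorValues and VectorUtil below for the real code):

    // Illustrative only. Two int4 values (each in 0..15) share one byte:
    // first value in the high nibble, second in the low nibble.
    static byte pack(byte first, byte second) {
      return (byte) ((first << 4) | second);
    }

    // Plain scalar dot product over int4 values stored one per byte.
    // Each product is at most 15 * 15 = 225, so the running sum fits
    // comfortably in an int for any realistic dimension.
    static int int4DotProductScalar(byte[] a, byte[] b) {
      int total = 0;
      for (int i = 0; i < a.length; i++) {
        total += a[i] * b[i];
      }
      return total;
    }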
parent bf193a7125
commit 07d3be59af
@@ -218,6 +218,9 @@ New Features
   This may improve paging logic especially when large segments are merged under memory pressure.
   (Uwe Schindler, Chris Hegarty, Robert Muir, Adrien Grand)

+* GITHUB#13197: Expand support for new scalar bit levels for HNSW vectors. This includes 4-bit vectors and an option
+  to compress them to gain a 50% reduction in memory usage. (Ben Trent)
+
 Improvements
 ---------------------
@@ -511,7 +511,7 @@ public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestBase {
     }
   }

-  private static ScoreDoc[] assertKNNSearch(
+  static ScoreDoc[] assertKNNSearch(
       IndexSearcher searcher,
       float[] queryVector,
       int k,
@@ -82,6 +82,16 @@ public class TestGenerateBwcIndices extends LuceneTestCase {
     sortedTest.createBWCIndex();
   }

+  public void testCreateInt8HNSWIndices() throws IOException {
+    TestInt8HnswBackwardsCompatibility int8HnswBackwardsCompatibility =
+        new TestInt8HnswBackwardsCompatibility(
+            Version.LATEST,
+            createPattern(
+                TestInt8HnswBackwardsCompatibility.INDEX_NAME,
+                TestInt8HnswBackwardsCompatibility.SUFFIX));
+    int8HnswBackwardsCompatibility.createBWCIndex();
+  }
+
   private boolean isInitialMajorVersionRelease() {
     return Version.LATEST.equals(Version.fromBits(Version.LATEST.major, 0, 0));
   }
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.backward_index;
+
+import static org.apache.lucene.backward_index.TestBasicBackwardsCompatibility.assertKNNSearch;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+import java.io.IOException;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99Codec;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.Version;
+
+public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTestBase {
+
+  static final String INDEX_NAME = "int8_hnsw";
+  static final String SUFFIX = "";
+  private static final Version FIRST_INT8_HNSW_VERSION = Version.LUCENE_9_10_1;
+  private static final String KNN_VECTOR_FIELD = "knn_field";
+  private static final int DOC_COUNT = 30;
+  private static final FieldType KNN_VECTOR_FIELD_TYPE =
+      KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE);
+  private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f};
+
+  public TestInt8HnswBackwardsCompatibility(Version version, String pattern) {
+    super(version, pattern);
+  }
+
+  /** Provides all sorted versions to the test-framework */
+  @ParametersFactory(argumentFormatting = "Lucene-Version:%1$s; Pattern: %2$s")
+  public static Iterable<Object[]> testVersionsFactory() throws IllegalAccessException {
+    return allVersion(INDEX_NAME, SUFFIX);
+  }
+
+  protected Codec getCodec() {
+    return new Lucene99Codec() {
+      @Override
+      public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+        return new Lucene99HnswScalarQuantizedVectorsFormat(
+            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
+            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+      }
+    };
+  }
+
+  @Override
+  protected boolean supportsVersion(Version version) {
+    return version.onOrAfter(FIRST_INT8_HNSW_VERSION);
+  }
+
+  @Override
+  void verifyUsesDefaultCodec(Directory dir, String name) throws IOException {
+    // We don't use the default codec
+  }
+
+  public void testInt8HnswIndexAndSearch() throws Exception {
+    IndexWriterConfig indexWriterConfig =
+        newIndexWriterConfig(new MockAnalyzer(random()))
+            .setOpenMode(IndexWriterConfig.OpenMode.APPEND)
+            .setCodec(getCodec())
+            .setMergePolicy(newLogMergePolicy());
+    try (IndexWriter writer = new IndexWriter(directory, indexWriterConfig)) {
+      // add 10 docs
+      for (int i = 0; i < 10; i++) {
+        writer.addDocument(knnDocument(i + DOC_COUNT));
+        if (random().nextBoolean()) {
+          writer.flush();
+        }
+      }
+      if (random().nextBoolean()) {
+        writer.forceMerge(1);
+      }
+      writer.commit();
+      try (IndexReader reader = DirectoryReader.open(directory)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
+        assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT + 10, "0");
+        assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
+      }
+    }
+    // This will confirm the docs are really sorted
+    TestUtil.checkIndex(directory);
+  }
+
+  @Override
+  protected void createIndex(Directory dir) throws IOException {
+    IndexWriterConfig conf =
+        new IndexWriterConfig(new MockAnalyzer(random()))
+            .setMaxBufferedDocs(10)
+            .setCodec(TestUtil.getDefaultCodec())
+            .setMergePolicy(NoMergePolicy.INSTANCE);
+    try (IndexWriter writer = new IndexWriter(dir, conf)) {
+      for (int i = 0; i < DOC_COUNT; i++) {
+        writer.addDocument(knnDocument(i));
+      }
+      writer.forceMerge(1);
+    }
+    try (DirectoryReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = new IndexSearcher(reader);
+      assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
+      assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
+    }
+  }
+
+  private static Document knnDocument(int id) {
+    Document doc = new Document();
+    float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id};
+    doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE));
+    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
+    return doc;
+  }
+
+  public void testReadOldIndices() throws Exception {
+    try (DirectoryReader reader = DirectoryReader.open(directory)) {
+      IndexSearcher searcher = new IndexSearcher(reader);
+      assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
+      assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
+    }
+  }
+}
Binary file not shown.
@@ -36,6 +36,8 @@ public class VectorUtilBenchmark {

   private byte[] bytesA;
   private byte[] bytesB;
+  private byte[] halfBytesA;
+  private byte[] halfBytesB;
   private float[] floatsA;
   private float[] floatsB;

@@ -51,6 +53,14 @@ public class VectorUtilBenchmark {
     bytesB = new byte[size];
     random.nextBytes(bytesA);
     random.nextBytes(bytesB);
+    // random half byte arrays for binary methods
+    // this means that all values must be between 0 and 15
+    halfBytesA = new byte[size];
+    halfBytesB = new byte[size];
+    for (int i = 0; i < size; ++i) {
+      halfBytesA[i] = (byte) random.nextInt(16);
+      halfBytesB[i] = (byte) random.nextInt(16);
+    }

     // random float arrays for float methods
     floatsA = new float[size];

@@ -94,6 +104,17 @@ public class VectorUtilBenchmark {
     return VectorUtil.squareDistance(bytesA, bytesB);
   }

+  @Benchmark
+  public int binaryHalfByteScalar() {
+    return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
+  }
+
+  @Benchmark
+  @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
+  public int binaryHalfByteVector() {
+    return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
+  }
+
   @Benchmark
   public float floatCosineScalar() {
     return VectorUtil.cosine(floatsA, floatsB);
@@ -65,7 +65,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {

   /** Constructs a format using default graph construction parameters */
   public Lucene99HnswScalarQuantizedVectorsFormat() {
-    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null);
+    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
   }

   /**
@@ -75,7 +75,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
    * @param beamWidth the size of the queue maintained during graph construction.
    */
   public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
-    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null);
+    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
   }

   /**
@@ -85,6 +85,11 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
    * @param beamWidth the size of the queue maintained during graph construction.
    * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
    *     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+   * @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
+   *     inclusive)
+   * @param compress whether to compress the vectors, if true, the vectors that are quantized with
+   *     lte 4 bits will be compressed into a single byte. If false, the vectors will be stored as
+   *     is. This provides a trade-off of memory usage and speed.
    * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
    *     it is calculated based on the vector field dimensions.
    * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
@@ -94,6 +99,8 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
       int maxConn,
       int beamWidth,
       int numMergeWorkers,
+      int bits,
+      boolean compress,
       Float confidenceInterval,
       ExecutorService mergeExec) {
     super("Lucene99HnswScalarQuantizedVectorsFormat");
@@ -127,7 +134,8 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
     } else {
       this.mergeExec = null;
     }
-    this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
+    this.flatVectorsFormat =
+        new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
   }

   @Override
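To make the new knobs concrete, here is a hypothetical usage sketch based on the expanded constructor above; the parameter values are illustrative, not recommendations:

    // Sketch only: configures int4 quantization with nibble compression.
    KnnVectorsFormat format =
        new Lucene99HnswScalarQuantizedVectorsFormat(
            16,    // maxConn: graph fan-out
            100,   // beamWidth: construction queue size
            1,     // numMergeWorkers: >1 requires a non-null mergeExec
            4,     // bits: half-byte quantization
            true,  // compress: pack two 4-bit values per stored byte
            null,  // confidenceInterval: null => calculated dynamically
            null); // mergeExec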
@@ -30,12 +30,17 @@ import org.apache.lucene.index.SegmentWriteState;
  * @lucene.experimental
  */
 public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

+  // The bits that are allowed for scalar quantization
+  // We only allow unsigned byte (8), signed byte (7), and half-byte (4)
+  private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
   public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";

   static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";

   static final int VERSION_START = 0;
-  static final int VERSION_CURRENT = VERSION_START;
+  static final int VERSION_ADD_BITS = 1;
+  static final int VERSION_CURRENT = VERSION_ADD_BITS;
   static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta";
   static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData";
   static final String META_EXTENSION = "vemq";

@@ -55,18 +60,27 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
    */
   final Float confidenceInterval;

+  final byte bits;
+  final boolean compress;
+
   /** Constructs a format using default graph construction parameters */
   public Lucene99ScalarQuantizedVectorsFormat() {
-    this(null);
+    this(null, 7, true);
   }

   /**
    * Constructs a format using the given graph construction parameters.
    *
    * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
-   *     it is calculated based on the vector field dimensions.
+   *     it is calculated dynamically.
+   * @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
+   *     inclusive)
+   * @param compress whether to compress the vectors, if true, the vectors that are quantized with
+   *     lte 4 bits will be compressed into a single byte. If false, the vectors will be stored as
+   *     is. This provides a trade-off of memory usage and speed.
    */
-  public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
+  public Lucene99ScalarQuantizedVectorsFormat(
+      Float confidenceInterval, int bits, boolean compress) {
     if (confidenceInterval != null
         && (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
             || confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) {
@@ -78,7 +92,12 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
               + "; confidenceInterval="
               + confidenceInterval);
     }
+    if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) {
+      throw new IllegalArgumentException("bits must be one of: 4, 7, 8; bits=" + bits);
+    }
+    this.bits = (byte) bits;
     this.confidenceInterval = confidenceInterval;
+    this.compress = compress;
   }

   public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@@ -92,6 +111,10 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
         + NAME
         + ", confidenceInterval="
         + confidenceInterval
+        + ", bits="
+        + bits
+        + ", compress="
+        + compress
         + ", rawVectorFormat="
         + rawVectorFormat
         + ")";
@@ -100,7 +123,7 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
   @Override
   public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
     return new Lucene99ScalarQuantizedVectorsWriter(
-        state, confidenceInterval, rawVectorFormat.fieldsWriter(state));
+        state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state));
   }

   @Override
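As a worked example of the ALLOWED_BITS check above: the mask is (1 << 8) | (1 << 7) | (1 << 4) = 256 + 128 + 16 = 400, so

    // bits = 4: 400 & (1 << 4) == 16  -> non-zero, accepted
    // bits = 7: 400 & (1 << 7) == 128 -> non-zero, accepted
    // bits = 5: 400 & (1 << 5) == 0   -> rejected: "bits must be one of: 4, 7, 8"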
@@ -82,7 +82,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
               Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
               state.segmentInfo.getId(),
               state.segmentSuffix);
-      readFields(meta, state.fieldInfos);
+      readFields(meta, versionMeta, state.fieldInfos);
     } catch (Throwable exception) {
       priorE = exception;
     } finally {
@@ -102,13 +102,14 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
     }
   }

-  private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
+  private void readFields(ChecksumIndexInput meta, int versionMeta, FieldInfos infos)
+      throws IOException {
     for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
       FieldInfo info = infos.fieldInfo(fieldNumber);
       if (info == null) {
         throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
       }
-      FieldEntry fieldEntry = readField(meta, info);
+      FieldEntry fieldEntry = readField(meta, versionMeta, info);
       validateFieldEntry(info, fieldEntry);
       fields.put(info.name, fieldEntry);
     }
@@ -126,8 +127,13 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
               + fieldEntry.dimension);
     }

-    // int8 quantized and calculated stored offset.
-    long quantizedVectorBytes = dimension + Float.BYTES;
+    final long quantizedVectorBytes;
+    if (fieldEntry.bits <= 4 && fieldEntry.compress) {
+      quantizedVectorBytes = ((dimension + 1) >> 1) + Float.BYTES;
+    } else {
+      // int8 quantized and calculated stored offset.
+      quantizedVectorBytes = dimension + Float.BYTES;
+    }
     long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size);
     if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
       throw new IllegalStateException(
@@ -209,6 +215,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
         fieldEntry.ordToDoc,
         fieldEntry.dimension,
         fieldEntry.size,
+        fieldEntry.bits,
+        fieldEntry.compress,
         fieldEntry.vectorDataOffset,
         fieldEntry.vectorDataLength,
         quantizedVectorData);
@@ -236,7 +244,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
     return size;
   }

-  private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+  private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
+      throws IOException {
     VectorEncoding vectorEncoding = readVectorEncoding(input);
     VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
     if (similarityFunction != info.getVectorSimilarityFunction()) {
@@ -248,7 +257,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
           + " != "
           + info.getVectorSimilarityFunction());
     }
-    return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
+    return new FieldEntry(input, versionMeta, vectorEncoding, info.getVectorSimilarityFunction());
   }

   @Override
@@ -261,6 +270,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
         fieldEntry.ordToDoc,
         fieldEntry.dimension,
         fieldEntry.size,
+        fieldEntry.bits,
+        fieldEntry.compress,
         fieldEntry.vectorDataOffset,
         fieldEntry.vectorDataLength,
         quantizedVectorData);
@@ -285,10 +296,13 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
     final long vectorDataLength;
     final ScalarQuantizer scalarQuantizer;
     final int size;
+    final byte bits;
+    final boolean compress;
     final OrdToDocDISIReaderConfiguration ordToDoc;

     FieldEntry(
         IndexInput input,
+        int versionMeta,
         VectorEncoding vectorEncoding,
         VectorSimilarityFunction similarityFunction)
         throws IOException {
@@ -299,12 +313,29 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader {
       dimension = input.readVInt();
       size = input.readInt();
       if (size > 0) {
-        float confidenceInterval = Float.intBitsToFloat(input.readInt());
-        float minQuantile = Float.intBitsToFloat(input.readInt());
-        float maxQuantile = Float.intBitsToFloat(input.readInt());
-        scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
+        if (versionMeta < Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) {
+          int floatBits = input.readInt(); // confidenceInterval, unused
+          if (floatBits == -1) {
+            throw new CorruptIndexException(
+                "Missing confidence interval for scalar quantizer", input);
+          }
+          this.bits = (byte) 7;
+          this.compress = false;
+          float minQuantile = Float.intBitsToFloat(input.readInt());
+          float maxQuantile = Float.intBitsToFloat(input.readInt());
+          scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, (byte) 7);
+        } else {
+          input.readInt(); // confidenceInterval, unused
+          this.bits = input.readByte();
+          this.compress = input.readByte() == 1;
+          float minQuantile = Float.intBitsToFloat(input.readInt());
+          float maxQuantile = Float.intBitsToFloat(input.readInt());
+          scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, bits);
+        }
       } else {
         scalarQuantizer = null;
+        this.bits = (byte) 7;
+        this.compress = false;
       }
       ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
     }
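The size validation above also shows where the advertised memory saving comes from. A worked example (dimension 768, plus the 4-byte score-correction float stored per vector):

    // int4 + compress: ((768 + 1) >> 1) + Float.BYTES = 384 + 4 = 388 bytes/vector
    // int7/int8:       768 + Float.BYTES              = 768 + 4 = 772 bytes/vector
    // => roughly the 50% reduction cited in the CHANGES entry.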
@@ -96,12 +96,20 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
   private final IndexOutput meta, quantizedVectorData;
   private final Float confidenceInterval;
   private final FlatVectorsWriter rawVectorDelegate;
+  private final byte bits;
+  private final boolean compress;
   private boolean finished;

   public Lucene99ScalarQuantizedVectorsWriter(
-      SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
+      SegmentWriteState state,
+      Float confidenceInterval,
+      byte bits,
+      boolean compress,
+      FlatVectorsWriter rawVectorDelegate)
       throws IOException {
     this.confidenceInterval = confidenceInterval;
+    this.bits = bits;
+    this.compress = compress;
     segmentWriteState = state;
     String metaFileName =
         IndexFileNames.segmentFileName(
@@ -145,12 +153,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
   public FlatFieldVectorsWriter<?> addField(
       FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
     if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
-      float confidenceInterval =
-          this.confidenceInterval == null
-              ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
-              : this.confidenceInterval;
+      if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) {
+        throw new IllegalArgumentException(
+            "bits="
+                + bits
+                + " is not supported for odd vector dimensions; vector dimension="
+                + fieldInfo.getVectorDimension());
+      }
       FieldWriter quantizedWriter =
-          new FieldWriter(confidenceInterval, fieldInfo, segmentWriteState.infoStream, indexWriter);
+          new FieldWriter(
+              confidenceInterval,
+              bits,
+              compress,
+              fieldInfo,
+              segmentWriteState.infoStream,
+              indexWriter);
       fields.add(quantizedWriter);
       indexWriter = quantizedWriter;
     }
@@ -164,24 +181,23 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     // the vectors directly to the new segment.
     // No need to use temporary file as we don't have to re-open for reading
     if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
-      ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
+      ScalarQuantizer mergedQuantizationState =
+          mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
       MergedQuantizedVectorValues byteVectorValues =
           MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
               fieldInfo, mergeState, mergedQuantizationState);
       long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
       DocsWithFieldSet docsWithField =
-          writeQuantizedVectorData(quantizedVectorData, byteVectorValues);
+          writeQuantizedVectorData(quantizedVectorData, byteVectorValues, bits, compress);
       long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
-      float confidenceInterval =
-          this.confidenceInterval == null
-              ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
-              : this.confidenceInterval;
       writeMeta(
           fieldInfo,
           segmentWriteState.segmentInfo.maxDoc(),
           vectorDataOffset,
           vectorDataLength,
           confidenceInterval,
+          bits,
+          compress,
           mergedQuantizationState.getLowerQuantile(),
           mergedQuantizationState.getUpperQuantile(),
           docsWithField);
@@ -195,7 +211,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
       // Simply merge the underlying delegate, which just copies the raw vector data to a new
       // segment file
       rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
-      ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
+      ScalarQuantizer mergedQuantizationState =
+          mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
       return mergeOneFieldToIndex(
           segmentWriteState, fieldInfo, mergeState, mergedQuantizationState);
     }
@@ -255,6 +272,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         vectorDataOffset,
         vectorDataLength,
         confidenceInterval,
+        bits,
+        compress,
         fieldData.minQuantile,
         fieldData.maxQuantile,
         fieldData.docsWithField);
@@ -266,6 +285,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
       long vectorDataOffset,
       long vectorDataLength,
       Float confidenceInterval,
+      byte bits,
+      boolean compress,
       Float lowerQuantile,
       Float upperQuantile,
       DocsWithFieldSet docsWithField)
@@ -280,11 +301,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     meta.writeInt(count);
     if (count > 0) {
       assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
-      meta.writeInt(
-          Float.floatToIntBits(
-              confidenceInterval != null
-                  ? confidenceInterval
-                  : calculateDefaultConfidenceInterval(field.getVectorDimension())));
+      meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval));
+      meta.writeByte(bits);
+      meta.writeByte(compress ? (byte) 1 : (byte) 0);
       meta.writeInt(Float.floatToIntBits(lowerQuantile));
       meta.writeInt(Float.floatToIntBits(upperQuantile));
     }
@@ -296,6 +315,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
   private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
     ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
     byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
+    byte[] compressedVector =
+        fieldData.compress
+            ? OffHeapQuantizedByteVectorValues.compressedArray(
+                fieldData.fieldInfo.getVectorDimension(), bits)
+            : null;
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
     float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
     for (float[] v : fieldData.floatVectors) {
@@ -307,7 +331,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {

       float offsetCorrection =
           scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
-      quantizedVectorData.writeBytes(vector, vector.length);
+      if (compressedVector != null) {
+        OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
+        quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
+      } else {
+        quantizedVectorData.writeBytes(vector, vector.length);
+      }
       offsetBuffer.putFloat(offsetCorrection);
       quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
       offsetBuffer.rewind();
@@ -348,6 +377,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         vectorDataOffset,
         quantizedVectorLength,
         confidenceInterval,
+        bits,
+        compress,
         fieldData.minQuantile,
         fieldData.maxQuantile,
         newDocsWithField);
@@ -356,6 +387,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
   private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
     ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
     byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
+    byte[] compressedVector =
+        fieldData.compress
+            ? OffHeapQuantizedByteVectorValues.compressedArray(
+                fieldData.fieldInfo.getVectorDimension(), bits)
+            : null;
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
     float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
     for (int ordinal : ordMap) {
@@ -367,29 +403,35 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
       }
       float offsetCorrection =
           scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
-      quantizedVectorData.writeBytes(vector, vector.length);
+      if (compressedVector != null) {
+        OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
+        quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
+      } else {
+        quantizedVectorData.writeBytes(vector, vector.length);
+      }
       offsetBuffer.putFloat(offsetCorrection);
       quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
       offsetBuffer.rewind();
     }
   }

-  private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState)
-      throws IOException {
-    assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32;
-    float confidenceInterval =
-        this.confidenceInterval == null
-            ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
-            : this.confidenceInterval;
-    return mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval);
-  }
-
   private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
       SegmentWriteState segmentWriteState,
       FieldInfo fieldInfo,
       MergeState mergeState,
       ScalarQuantizer mergedQuantizationState)
       throws IOException {
     if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
       segmentWriteState.infoStream.message(
           QUANTIZED_VECTOR_COMPONENT,
           "quantized field="
               + " confidenceInterval="
               + confidenceInterval
               + " minQuantile="
               + mergedQuantizationState.getLowerQuantile()
               + " maxQuantile="
               + mergedQuantizationState.getUpperQuantile());
     }
     long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
     IndexOutput tempQuantizedVectorData =
         segmentWriteState.directory.createTempOutput(
@@ -401,7 +443,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
             fieldInfo, mergeState, mergedQuantizationState);
     DocsWithFieldSet docsWithField =
-        writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
+        writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues, bits, compress);
     CodecUtil.writeFooter(tempQuantizedVectorData);
     IOUtils.close(tempQuantizedVectorData);
     quantizationDataInput =
@@ -421,6 +463,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         vectorDataOffset,
         vectorDataLength,
         confidenceInterval,
+        bits,
+        compress,
         mergedQuantizationState.getLowerQuantile(),
         mergedQuantizationState.getUpperQuantile(),
         docsWithField);
@@ -438,6 +482,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
             new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
                 fieldInfo.getVectorDimension(),
                 docsWithField.cardinality(),
+                bits,
+                compress,
                 quantizationDataInput)));
     } finally {
       if (success == false) {
@@ -449,9 +495,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
   }

   static ScalarQuantizer mergeQuantiles(
-      List<ScalarQuantizer> quantizationStates,
-      List<Integer> segmentSizes,
-      float confidenceInterval) {
+      List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, byte bits) {
     assert quantizationStates.size() == segmentSizes.size();
     if (quantizationStates.isEmpty()) {
       return null;
@@ -466,10 +510,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
       lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i);
       upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i);
       totalCount += segmentSizes.get(i);
+      if (quantizationStates.get(i).getBits() != bits) {
+        return null;
+      }
     }
     lowerQuantile /= totalCount;
     upperQuantile /= totalCount;
-    return new ScalarQuantizer(lowerQuantile, upperQuantile, confidenceInterval);
+    return new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
   }

   /**
@@ -531,11 +578,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
    * @param mergeState The merge state
    * @param fieldInfo The field info
    * @param confidenceInterval The confidence interval
+   * @param bits The number of bits
    * @return The merged quantiles
    * @throws IOException If there is a low-level I/O error
    */
   public static ScalarQuantizer mergeAndRecalculateQuantiles(
-      MergeState mergeState, FieldInfo fieldInfo, float confidenceInterval) throws IOException {
+      MergeState mergeState, FieldInfo fieldInfo, Float confidenceInterval, byte bits)
+      throws IOException {
     assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
     List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
     List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
     for (int i = 0; i < mergeState.liveDocs.length; i++) {
@@ -550,14 +600,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         segmentSizes.add(fvv.size());
       }
     }
-    ScalarQuantizer mergedQuantiles =
-        mergeQuantiles(quantizationStates, segmentSizes, confidenceInterval);
+    ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, bits);
     // Segments not providing quantization state indicates that their quantiles were never
     // calculated.
     // To be safe, we should always recalculate given a sample set over all the float vectors in the
     // merged segment view
-    if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
+    if (mergedQuantiles == null
+        // For smaller `bits` values, we should always recalculate the quantiles
+        // TODO: this is very conservative, could we reuse information for even int4 quantization?
+        || bits <= 4
+        || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
       int numVectors = 0;
       FloatVectorValues vectorValues =
           KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
@@ -568,10 +621,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         numVectors++;
       }
       mergedQuantiles =
-          ScalarQuantizer.fromVectors(
-              KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
-              confidenceInterval,
-              numVectors);
+          confidenceInterval == null
+              ? ScalarQuantizer.fromVectorsAutoInterval(
+                  KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
+                  fieldInfo.getVectorSimilarityFunction(),
+                  numVectors,
+                  bits)
+              : ScalarQuantizer.fromVectors(
+                  KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
+                  confidenceInterval,
+                  numVectors,
+                  bits);
     }
     return mergedQuantiles;
   }
@@ -600,8 +660,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
    * Writes the vector values to the output and returns a set of documents that contains vectors.
    */
   public static DocsWithFieldSet writeQuantizedVectorData(
-      IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
+      IndexOutput output,
+      QuantizedByteVectorValues quantizedByteVectorValues,
+      byte bits,
+      boolean compress)
+      throws IOException {
     DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+    final byte[] compressedVector =
+        compress
+            ? OffHeapQuantizedByteVectorValues.compressedArray(
+                quantizedByteVectorValues.dimension(), bits)
+            : null;
     for (int docV = quantizedByteVectorValues.nextDoc();
         docV != NO_MORE_DOCS;
         docV = quantizedByteVectorValues.nextDoc()) {
@@ -609,7 +678,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
       byte[] binaryValue = quantizedByteVectorValues.vectorValue();
       assert binaryValue.length == quantizedByteVectorValues.dimension()
           : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
-      output.writeBytes(binaryValue, binaryValue.length);
+      if (compressedVector != null) {
+        OffHeapQuantizedByteVectorValues.compressBytes(binaryValue, compressedVector);
+        output.writeBytes(compressedVector, compressedVector.length);
+      } else {
+        output.writeBytes(binaryValue, binaryValue.length);
+      }
       output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
       docsWithField.add(docV);
     }
@@ -625,7 +699,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
     private final List<float[]> floatVectors;
     private final FieldInfo fieldInfo;
-    private final float confidenceInterval;
+    private final Float confidenceInterval;
+    private final byte bits;
+    private final boolean compress;
     private final InfoStream infoStream;
     private final boolean normalize;
     private float minQuantile = Float.POSITIVE_INFINITY;
@@ -635,17 +711,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {

     @SuppressWarnings("unchecked")
     FieldWriter(
-        float confidenceInterval,
+        Float confidenceInterval,
+        byte bits,
+        boolean compress,
         FieldInfo fieldInfo,
         InfoStream infoStream,
         KnnFieldVectorsWriter<?> indexWriter) {
       super((KnnFieldVectorsWriter<float[]>) indexWriter);
       this.confidenceInterval = confidenceInterval;
+      this.bits = bits;
       this.fieldInfo = fieldInfo;
       this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
       this.floatVectors = new ArrayList<>();
       this.infoStream = infoStream;
       this.docsWithField = new DocsWithFieldSet();
+      this.compress = compress;
     }

     void finish() throws IOException {
@@ -656,13 +736,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         finished = true;
         return;
       }
+      FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
       ScalarQuantizer quantizer =
-          ScalarQuantizer.fromVectors(
-              new FloatVectorWrapper(
-                  floatVectors,
-                  fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
-              confidenceInterval,
-              floatVectors.size());
+          confidenceInterval == null
+              ? ScalarQuantizer.fromVectorsAutoInterval(
+                  floatVectorValues,
+                  fieldInfo.getVectorSimilarityFunction(),
+                  floatVectors.size(),
+                  bits)
+              : ScalarQuantizer.fromVectors(
+                  floatVectorValues, confidenceInterval, floatVectors.size(), bits);
       minQuantile = quantizer.getLowerQuantile();
       maxQuantile = quantizer.getUpperQuantile();
       if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
@@ -671,6 +754,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
             "quantized field="
                 + " confidenceInterval="
                 + confidenceInterval
+                + " bits="
+                + bits
                 + " minQuantile="
                 + minQuantile
                 + " maxQuantile="
@@ -681,7 +766,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {

     ScalarQuantizer createQuantizer() {
       assert finished;
-      return new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval);
+      return new ScalarQuantizer(minQuantile, maxQuantile, bits);
     }

     @Override
@@ -765,7 +850,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     }
   }

-  private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
+  static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
     private final QuantizedByteVectorValues values;

     QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
@@ -799,6 +884,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         // Or we have never been quantized.
         if (reader == null
             || reader.getQuantizationState(fieldInfo.name) == null
+            // For smaller `bits` values, we should always recalculate the quantiles
+            // TODO: this is very conservative, could we reuse information for even int4
+            // quantization?
+            || scalarQuantizer.getBits() <= 4
             || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
           sub =
               new QuantizedByteVectorValueSub(
@@ -884,7 +973,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     }
   }

-  private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
+  static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
     private final FloatVectorValues values;
     private final ScalarQuantizer quantizer;
     private final byte[] quantizedVector;
@@ -999,14 +1088,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
     }
   }

-  private static final class OffsetCorrectedQuantizedByteVectorValues
-      extends QuantizedByteVectorValues {
+  static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues {

     private final QuantizedByteVectorValues in;
     private final VectorSimilarityFunction vectorSimilarityFunction;
     private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;

-    private OffsetCorrectedQuantizedByteVectorValues(
+    OffsetCorrectedQuantizedByteVectorValues(
         QuantizedByteVectorValues in,
         VectorSimilarityFunction vectorSimilarityFunction,
         ScalarQuantizer scalarQuantizer,
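One detail worth calling out from the writer above: mergeQuantiles combines per-segment quantiles as a size-weighted average, and bails out (returns null, forcing recomputation) if any segment was quantized with a different bit count. An illustrative calculation with two segments of 100 and 300 vectors:

    // lower quantiles 0.1f and 0.2f, upper quantiles 0.9f and 1.0f:
    // mergedLower = (0.1f * 100 + 0.2f * 300) / 400 = 0.175f
    // mergedUpper = (0.9f * 100 + 1.0f * 300) / 400 = 0.975f
    // For bits <= 4 the merged result is discarded anyway and the quantiles
    // are recomputed from a sample of the merged vectors (see the bits <= 4
    // branch in mergeAndRecalculateQuantiles).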
@@ -36,6 +36,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {

   protected final int dimension;
   protected final int size;
+  protected final int numBytes;
+  protected final byte bits;
+  protected final boolean compress;

   protected final IndexInput slice;
   protected final byte[] binaryValue;
   protected final ByteBuffer byteBuffer;

@@ -43,11 +47,52 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
   protected int lastOrd = -1;
   protected final float[] scoreCorrectionConstant = new float[1];

-  OffHeapQuantizedByteVectorValues(int dimension, int size, IndexInput slice) {
+  static void decompressBytes(byte[] compressed, int numBytes) {
+    if (numBytes == compressed.length) {
+      return;
+    }
+    if (numBytes << 1 != compressed.length) {
+      throw new IllegalArgumentException(
+          "numBytes: " + numBytes + " does not match compressed length: " + compressed.length);
+    }
+    for (int i = 0; i < numBytes; ++i) {
+      compressed[numBytes + i] = (byte) (compressed[i] & 0x0F);
+      compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4);
+    }
+  }
+
+  static byte[] compressedArray(int dimension, byte bits) {
+    if (bits <= 4) {
+      return new byte[(dimension + 1) >> 1];
+    } else {
+      return null;
+    }
+  }
+
+  static void compressBytes(byte[] raw, byte[] compressed) {
+    if (compressed.length != ((raw.length + 1) >> 1)) {
+      throw new IllegalArgumentException(
+          "compressed length: " + compressed.length + " does not match raw length: " + raw.length);
+    }
+    for (int i = 0; i < compressed.length; ++i) {
+      int v = (raw[i] << 4) | raw[compressed.length + i];
+      compressed[i] = (byte) v;
+    }
+  }
+
+  OffHeapQuantizedByteVectorValues(
+      int dimension, int size, byte bits, boolean compress, IndexInput slice) {
     this.dimension = dimension;
     this.size = size;
     this.slice = slice;
-    this.byteSize = dimension + Float.BYTES;
+    this.bits = bits;
+    this.compress = compress;
+    if (bits <= 4 && compress) {
+      this.numBytes = (dimension + 1) >> 1;
+    } else {
+      this.numBytes = dimension;
+    }
+    this.byteSize = this.numBytes + Float.BYTES;
     byteBuffer = ByteBuffer.allocate(dimension);
     binaryValue = byteBuffer.array();
   }
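A worked round-trip for compressBytes/decompressBytes above, with dimension = 4. Note that the pairing is raw[i] with raw[dim/2 + i], not adjacent elements:

    // raw = {1, 2, 3, 4}
    // compressBytes: compressed[0] = (1 << 4) | 3 = 0x13
    //                compressed[1] = (2 << 4) | 4 = 0x24
    // decompressBytes({0x13, 0x24, _, _}, 2) expands in place:
    //   low nibbles  -> positions 2..3: {3, 4}
    //   high nibbles -> positions 0..1: {1, 2}
    // restoring {1, 2, 3, 4}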
@@ -68,8 +113,9 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
       return binaryValue;
     }
     slice.seek((long) targetOrd * byteSize);
-    slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension);
+    slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
     slice.readFloats(scoreCorrectionConstant, 0, 1);
+    decompressBytes(binaryValue, numBytes);
     lastOrd = targetOrd;
     return binaryValue;
   }
@@ -83,6 +129,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
       OrdToDocDISIReaderConfiguration configuration,
       int dimension,
       int size,
+      byte bits,
+      boolean compress,
       long quantizedVectorDataOffset,
       long quantizedVectorDataLength,
       IndexInput vectorData)
@@ -94,9 +142,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
         vectorData.slice(
             "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
     if (configuration.isDense()) {
-      return new DenseOffHeapVectorValues(dimension, size, bytesSlice);
+      return new DenseOffHeapVectorValues(dimension, size, bits, compress, bytesSlice);
     } else {
-      return new SparseOffHeapVectorValues(configuration, dimension, size, vectorData, bytesSlice);
+      return new SparseOffHeapVectorValues(
+          configuration, dimension, size, bits, compress, vectorData, bytesSlice);
     }
   }

@@ -108,8 +157,9 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {

     private int doc = -1;

-    public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
-      super(dimension, size, slice);
+    public DenseOffHeapVectorValues(
+        int dimension, int size, byte bits, boolean compress, IndexInput slice) {
+      super(dimension, size, bits, compress, slice);
     }

     @Override
@@ -138,7 +188,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {

     @Override
     public DenseOffHeapVectorValues copy() throws IOException {
-      return new DenseOffHeapVectorValues(dimension, size, slice.clone());
+      return new DenseOffHeapVectorValues(dimension, size, bits, compress, slice.clone());
     }

     @Override
@@ -158,10 +208,12 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
         OrdToDocDISIReaderConfiguration configuration,
         int dimension,
         int size,
+        byte bits,
+        boolean compress,
         IndexInput dataIn,
         IndexInput slice)
         throws IOException {
-      super(dimension, size, slice);
+      super(dimension, size, bits, compress, slice);
       this.configuration = configuration;
       this.dataIn = dataIn;
       this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
@@ -191,7 +243,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {

     @Override
     public SparseOffHeapVectorValues copy() throws IOException {
-      return new SparseOffHeapVectorValues(configuration, dimension, size, dataIn, slice.clone());
+      return new SparseOffHeapVectorValues(
+          configuration, dimension, size, bits, compress, dataIn, slice.clone());
     }

     @Override
@@ -221,7 +274,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
   private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {

     public EmptyOffHeapVectorValues(int dimension) {
-      super(dimension, 0, null);
+      super(dimension, 0, (byte) 7, false, null);
     }

     private int doc = -1;
@@ -151,6 +151,11 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
     return total;
   }

+  @Override
+  public int int4DotProduct(byte[] a, byte[] b) {
+    return dotProduct(a, b);
+  }
+
   @Override
   public float cosine(byte[] a, byte[] b) {
     // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
@@ -36,6 +36,9 @@ public interface VectorUtilSupport {
   /** Returns the dot product computed over signed bytes. */
   int dotProduct(byte[] a, byte[] b);

+  /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */
+  int int4DotProduct(byte[] a, byte[] b);
+
   /** Returns the cosine similarity between the two byte vectors. */
   float cosine(byte[] a, byte[] b);
@@ -175,6 +175,13 @@ public final class VectorUtil {
     return IMPL.dotProduct(a, b);
   }

+  public static int int4DotProduct(byte[] a, byte[] b) {
+    if (a.length != b.length) {
+      throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
+    }
+    return IMPL.int4DotProduct(a, b);
+  }
+
   /**
    * Dot product score computed over signed bytes, scaled to be in [0, 1].
    *
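A minimal usage sketch of the new helper (inputs assumed to hold int4 values, one per byte, as produced by the quantizer):

    byte[] a = {1, 2, 3, 4};
    byte[] b = {4, 3, 2, 1};
    int dp = VectorUtil.int4DotProduct(a, b); // 1*4 + 2*3 + 3*2 + 4*1 = 20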
@@ -76,7 +76,7 @@ public class ScalarQuantizedRandomVectorScorer
     this.queryOffset = correction;
     this.similarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
-            similarityFunction, scalarQuantizer.getConstantMultiplier());
+            similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
     this.values = values;
   }
@@ -37,7 +37,7 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
       RandomAccessQuantizedByteVectorValues values) {
     this.similarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
-            similarityFunction, scalarQuantizer.getConstantMultiplier());
+            similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
     this.values = values;
   }
@@ -33,14 +33,17 @@ public interface ScalarQuantizedVectorSimilarity {
    *
    * @param sim similarity function
    * @param constMultiplier constant multiplier used for quantization
+   * @param bits number of bits used for quantization
    * @return a {@link ScalarQuantizedVectorSimilarity} that applies the appropriate corrections
    */
   static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
-      VectorSimilarityFunction sim, float constMultiplier) {
+      VectorSimilarityFunction sim, float constMultiplier, byte bits) {
     return switch (sim) {
       case EUCLIDEAN -> new Euclidean(constMultiplier);
-      case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier);
-      case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier);
+      case COSINE, DOT_PRODUCT -> new DotProduct(
+          constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
+      case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(
+          constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
     };
   }

@@ -66,15 +69,17 @@ public interface ScalarQuantizedVectorSimilarity {
   /** Calculates dot product on quantized vectors, applying the appropriate corrections */
   class DotProduct implements ScalarQuantizedVectorSimilarity {
     private final float constMultiplier;
+    private final ByteVectorComparator comparator;
 
-    public DotProduct(float constMultiplier) {
+    public DotProduct(float constMultiplier, ByteVectorComparator comparator) {
       this.constMultiplier = constMultiplier;
+      this.comparator = comparator;
     }
 
     @Override
     public float score(
         byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
-      int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
+      int dotProduct = comparator.compare(storedVector, queryVector);
       float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
       return (1 + adjustedDistance) / 2;
     }

@@ -83,17 +88,24 @@ public interface ScalarQuantizedVectorSimilarity {
   /** Calculates max inner product on quantized vectors, applying the appropriate corrections */
   class MaximumInnerProduct implements ScalarQuantizedVectorSimilarity {
     private final float constMultiplier;
+    private final ByteVectorComparator comparator;
 
-    public MaximumInnerProduct(float constMultiplier) {
+    public MaximumInnerProduct(float constMultiplier, ByteVectorComparator comparator) {
       this.constMultiplier = constMultiplier;
+      this.comparator = comparator;
     }
 
     @Override
     public float score(
         byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
-      int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
+      int dotProduct = comparator.compare(storedVector, queryVector);
       float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
       return scaleMaxInnerProductScore(adjustedDistance);
     }
   }
 
+  /** Compares two byte vectors */
+  interface ByteVectorComparator {
+    int compare(byte[] v1, byte[] v2);
+  }
 }

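The factory above now keys the comparator off the bit count: anything at or below 4 bits routes to VectorUtil::int4DotProduct, everything else keeps the plain dot product. A hedged calling sketch (the quantizer variable and the byte arrays are assumed to exist; only the method names come from the diff):

    // Sketch: obtaining a similarity that takes the int4 comparison path.
    ScalarQuantizedVectorSimilarity sim =
        ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
            VectorSimilarityFunction.DOT_PRODUCT,
            quantizer.getConstantMultiplier(), // alpha^2 for the chosen quantiles
            (byte) 4);                         // bits <= 4 selects VectorUtil::int4DotProduct
    float score = sim.score(queryBytes, queryOffset, storedBytes, storedOffset);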
@@ -19,11 +19,15 @@ package org.apache.lucene.util.quantization;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Random;
 import java.util.stream.IntStream;
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.HitQueue;
+import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.util.IntroSelector;
 import org.apache.lucene.util.Selector;

@@ -74,20 +78,23 @@ public class ScalarQuantizer {
 
   private final float alpha;
   private final float scale;
-  private final float minQuantile, maxQuantile, confidenceInterval;
+  private final byte bits;
+  private final float minQuantile, maxQuantile;
 
   /**
    * @param minQuantile the lower quantile of the distribution
    * @param maxQuantile the upper quantile of the distribution
-   * @param confidenceInterval The configured confidence interval used to calculate the quantiles.
+   * @param bits the number of bits to use for quantization
    */
-  public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInterval) {
+  public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
     assert maxQuantile >= minQuantile;
+    assert bits > 0 && bits <= 8;
     this.minQuantile = minQuantile;
     this.maxQuantile = maxQuantile;
-    this.scale = 127f / (maxQuantile - minQuantile);
-    this.alpha = (maxQuantile - minQuantile) / 127f;
-    this.confidenceInterval = confidenceInterval;
+    this.bits = bits;
+    final float divisor = (float) ((1 << bits) - 1);
+    this.scale = divisor / (maxQuantile - minQuantile);
+    this.alpha = (maxQuantile - minQuantile) / divisor;
   }
 
   /**
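With the divisor generalized from the hard-coded 127 to (1 << bits) - 1, the quantized range follows the bit count. A worked example under assumed quantiles (the values here are illustrative, not defaults):

    // Sketch: the scale/alpha arithmetic for 4-bit quantization.
    static void fourBitExample() {
      byte bits = 4;
      float minQuantile = -1f, maxQuantile = 1f;           // illustrative quantiles
      float divisor = (1 << bits) - 1;                     // 15
      float scale = divisor / (maxQuantile - minQuantile); // 7.5
      float alpha = (maxQuantile - minQuantile) / divisor; // ~0.1333
      // A source value of 0.2f maps to round(7.5 * (0.2 - (-1))) = round(9.0) = 9,
      // which lies in [0, 15] and therefore fits in half a byte.
      int q = Math.round(scale * (0.2f - minQuantile));
      assert q == 9;
      float back = alpha * q + minQuantile; // dequantizes to ~0.2f
    }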
@@ -100,31 +107,38 @@ public class ScalarQuantizer {
    */
   public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
     assert src.length == dest.length;
-    float correctiveOffset = 0f;
+    float correction = 0;
     for (int i = 0; i < src.length; i++) {
-      float v = src[i];
-      // Make sure the value is within the quantile range, cutting off the tails
-      // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
-      // minQuantile)
-      float dx = Math.max(minQuantile, Math.min(maxQuantile, src[i])) - minQuantile;
-      // Scale the value to the range [0, 127], this is our quantized value
-      // scale = 127/(maxQuantile - minQuantile)
-      float dxs = scale * dx;
-      // We multiply by `alpha` here to get the quantized value back into the original range
-      // to aid in calculating the corrective offset
-      float dxq = Math.round(dxs) * alpha;
-      // Calculate the corrective offset that needs to be applied to the score
-      // in addition to the `byte * minQuantile * alpha` term in the equation
-      // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
-      // will be rounded to the nearest whole number and lose some accuracy
-      // Additionally, we account for the global correction of `minQuantile^2` in the equation
-      correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
-      dest[i] = (byte) Math.round(dxs);
+      correction += quantizeFloat(src[i], dest, i);
     }
     if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
       return 0;
     }
-    return correctiveOffset;
+    return correction;
   }
 
+  private float quantizeFloat(float v, byte[] dest, int destIndex) {
+    assert dest == null || destIndex < dest.length;
+    // Make sure the value is within the quantile range, cutting off the tails
+    // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
+    // minQuantile)
+    float dx = v - minQuantile;
+    float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
+    // Scale the value to the range [0, 127], this is our quantized value
+    // scale = 127/(maxQuantile - minQuantile)
+    float dxs = scale * dxc;
+    // We multiply by `alpha` here to get the quantized value back into the original range
+    // to aid in calculating the corrective offset
+    float dxq = Math.round(dxs) * alpha;
+    if (dest != null) {
+      dest[destIndex] = (byte) Math.round(dxs);
+    }
+    // Calculate the corrective offset that needs to be applied to the score
+    // in addition to the `byte * minQuantile * alpha` term in the equation
+    // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
+    // will be rounded to the nearest whole number and lose some accuracy
+    // Additionally, we account for the global correction of `minQuantile^2` in the equation
+    return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
+  }
 
   /**
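The corrective offset stored per vector follows from expanding the float dot product around the quantized values. A sketch of the algebra behind the comments above, writing $m$ for minQuantile, $d_x = \mathrm{clamp}(x) - m$, and $\hat d_x = \alpha\,\mathrm{round}(\mathrm{scale}\cdot d_x)$, ignoring clamping and keeping only first-order rounding error:

    x \cdot y = (m + d_x)(m + d_y) = m\left(x - \tfrac{m}{2}\right) + m\left(y - \tfrac{m}{2}\right) + d_x d_y
    d_x d_y \approx \hat d_x \hat d_y + (d_x - \hat d_x)\,\hat d_x + (d_y - \hat d_y)\,\hat d_y,
        \qquad \hat d_x \hat d_y = \alpha^2\, q_x q_y

so the float score is approximately alpha^2 times the quantized dot product plus one offset per vector, which is exactly `dotProduct * constMultiplier + queryOffset + vectorOffset` in ScalarQuantizedVectorSimilarity. Note the approximation charges each vector its own rounding term (d - d_hat) * d_hat rather than the cross term, which is what lets the correction be precomputed per vector at index time.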
@@ -146,10 +160,7 @@ public class ScalarQuantizer {
     for (int i = 0; i < quantizedVector.length; i++) {
       // dequantize the old value in order to recalculate the corrective offset
       float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
-      float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
-      float dxs = scale * dx;
-      float dxq = Math.round(dxs) * alpha;
-      correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
+      correctiveOffset += quantizeFloat(v, null, 0);
     }
     return correctiveOffset;
   }

@@ -160,7 +171,7 @@ public class ScalarQuantizer {
    * @param src the source vector
    * @param dest the destination vector
    */
-  public void deQuantize(byte[] src, float[] dest) {
+  void deQuantize(byte[] src, float[] dest) {
     assert src.length == dest.length;
     for (int i = 0; i < src.length; i++) {
       dest[i] = (alpha * src[i]) + minQuantile;

@@ -175,14 +186,14 @@ public class ScalarQuantizer {
     return maxQuantile;
   }
 
-  public float getConfidenceInterval() {
-    return confidenceInterval;
-  }
-
   public float getConstantMultiplier() {
     return alpha * alpha;
   }
 
+  public byte getBits() {
+    return bits;
+  }
+
   @Override
   public String toString() {
     return "ScalarQuantizer{"

@@ -190,14 +201,14 @@ public class ScalarQuantizer {
         + minQuantile
         + ", maxQuantile="
         + maxQuantile
-        + ", confidenceInterval="
-        + confidenceInterval
+        + ", bits="
+        + bits
         + '}';
   }
 
   private static final Random random = new Random(42);
 
-  static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
+  private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
     int[] vectorsToTake = IntStream.range(0, sampleSize).toArray();
     for (int i = sampleSize; i < numFloatVecs; i++) {
       int j = random.nextInt(i + 1);
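reservoirSampleIndices is standard reservoir sampling (Vitter's Algorithm R): seed the reservoir with the first sampleSize ordinals, then give each later ordinal i a sampleSize/(i+1) chance of evicting a random resident, leaving every ordinal in the sample with equal probability sampleSize/numFloatVecs. A self-contained sketch of the technique (not the Lucene method, whose tail is elided by the hunk boundary above):

    // Sketch: Algorithm R over ordinals [0, n).
    static int[] reservoirSample(int n, int sampleSize, java.util.Random rnd) {
      int[] reservoir = new int[sampleSize];
      for (int i = 0; i < sampleSize; i++) {
        reservoir[i] = i;
      }
      for (int i = sampleSize; i < n; i++) {
        int j = rnd.nextInt(i + 1); // uniform in [0, i]
        if (j < sampleSize) {
          reservoir[j] = i;         // evict with probability sampleSize / (i + 1)
        }
      }
      return reservoir;
    }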
@@ -220,26 +231,35 @@ public class ScalarQuantizer {
    * @param confidenceInterval the confidence interval used to calculate the quantiles
    * @param totalVectorCount the total number of live float vectors in the index. This is vital for
    *     accounting for deleted documents when calculating the quantiles.
+   * @param bits the number of bits to use for quantization
    * @return A new {@link ScalarQuantizer} instance
    * @throws IOException if there is an error reading the float vector values
    */
   public static ScalarQuantizer fromVectors(
-      FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount)
+      FloatVectorValues floatVectorValues,
+      float confidenceInterval,
+      int totalVectorCount,
+      byte bits)
       throws IOException {
     return fromVectors(
-        floatVectorValues, confidenceInterval, totalVectorCount, SCALAR_QUANTIZATION_SAMPLE_SIZE);
+        floatVectorValues,
+        confidenceInterval,
+        totalVectorCount,
+        bits,
+        SCALAR_QUANTIZATION_SAMPLE_SIZE);
   }
 
   static ScalarQuantizer fromVectors(
       FloatVectorValues floatVectorValues,
       float confidenceInterval,
       int totalVectorCount,
+      byte bits,
       int quantizationSampleSize)
       throws IOException {
     assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
     assert quantizationSampleSize > SCRATCH_SIZE;
     if (totalVectorCount == 0) {
-      return new ScalarQuantizer(0f, 0f, confidenceInterval);
+      return new ScalarQuantizer(0f, 0f, bits);
     }
     if (confidenceInterval == 1f) {
       float min = Float.POSITIVE_INFINITY;

@@ -250,13 +270,14 @@ public class ScalarQuantizer {
         max = Math.max(max, v);
       }
     }
-    return new ScalarQuantizer(min, max, confidenceInterval);
+    return new ScalarQuantizer(min, max, bits);
   }
   final float[] quantileGatheringScratch =
       new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
   int count = 0;
-  double upperSum = 0;
-  double lowerSum = 0;
+  double[] upperSum = new double[1];
+  double[] lowerSum = new double[1];
+  float[] confidenceIntervals = new float[] {confidenceInterval};
   if (totalVectorCount <= quantizationSampleSize) {
     int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
     int i = 0;

@@ -266,10 +287,7 @@ public class ScalarQuantizer {
           vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
       i++;
       if (i == scratchSize) {
-        float[] upperAndLower =
-            getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
-        upperSum += upperAndLower[1];
-        lowerSum += upperAndLower[0];
+        extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
         i = 0;
         count++;
       }

@@ -277,10 +295,8 @@ public class ScalarQuantizer {
     // Note, we purposefully don't use the rest of the scratch state if we have fewer than
     // `SCRATCH_SIZE` vectors, mainly because if we are sampling so few vectors then we don't
     // want to be adversely affected by the extreme confidence intervals over small sample sizes
-    return new ScalarQuantizer(
-        (float) lowerSum / count, (float) upperSum / count, confidenceInterval);
+    return new ScalarQuantizer((float) lowerSum[0] / count, (float) upperSum[0] / count, bits);
   }
   // Reservoir sample the vector ordinals we want to read
   int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize);
   int index = 0;
   int idx = 0;
@@ -296,16 +312,213 @@ public class ScalarQuantizer {
           vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length);
       idx++;
       if (idx == SCRATCH_SIZE) {
-        float[] upperAndLower =
-            getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval);
-        upperSum += upperAndLower[1];
-        lowerSum += upperAndLower[0];
+        extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
         count++;
         idx = 0;
       }
     }
-    return new ScalarQuantizer(
-        (float) lowerSum / count, (float) upperSum / count, confidenceInterval);
+    return new ScalarQuantizer((float) lowerSum[0] / count, (float) upperSum[0] / count, bits);
   }
 
+  public static ScalarQuantizer fromVectorsAutoInterval(
+      FloatVectorValues floatVectorValues,
+      VectorSimilarityFunction function,
+      int totalVectorCount,
+      byte bits)
+      throws IOException {
+    if (totalVectorCount == 0) {
+      return new ScalarQuantizer(0f, 0f, bits);
+    }
+
+    int sampleSize = Math.min(totalVectorCount, 1000);
+    final float[] quantileGatheringScratch =
+        new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
+    int count = 0;
+    double[] upperSum = new double[2];
+    double[] lowerSum = new double[2];
+    final List<float[]> sampledDocs = new ArrayList<>(sampleSize);
+    float[] confidenceIntervals =
+        new float[] {
+          1
+              - Math.min(32, floatVectorValues.dimension() / 10f)
+                  / (floatVectorValues.dimension() + 1),
+          1 - 1f / (floatVectorValues.dimension() + 1)
+        };
+    if (totalVectorCount <= sampleSize) {
+      int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
+      int i = 0;
+      while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
+        gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i);
+        i++;
+        if (i == scratchSize) {
+          extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
+          i = 0;
+          count++;
+        }
+      }
+    } else {
+      // Reservoir sample the vector ordinals we want to read
+      int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, 1000);
+      // TODO make this faster by .advance()ing & dual iterator
+      int index = 0;
+      int idx = 0;
+      for (int i : vectorsToTake) {
+        while (index <= i) {
+          // We cannot use `advance(docId)` as MergedVectorValues does not support it
+          floatVectorValues.nextDoc();
+          index++;
+        }
+        assert floatVectorValues.docID() != NO_MORE_DOCS;
+        gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx);
+        idx++;
+        if (idx == SCRATCH_SIZE) {
+          extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
+          count++;
+          idx = 0;
+        }
+      }
+    }
+
+    // Here we gather the upper and lower bounds for the quantile grid search
+    float al = (float) lowerSum[1] / count;
+    float bu = (float) upperSum[1] / count;
+    final float au = (float) lowerSum[0] / count;
+    final float bl = (float) upperSum[0] / count;
+    final float[] lowerCandidates = new float[16];
+    final float[] upperCandidates = new float[16];
+    int idx = 0;
+    for (float i = 0f; i < 32f; i += 2f) {
+      lowerCandidates[idx] = al + i * (au - al) / 32f;
+      upperCandidates[idx] = bl + i * (bu - bl) / 32f;
+      idx++;
+    }
+    // Now we need to find the best candidate pair by correlating the true quantized nearest
+    // neighbor scores with the float vector scores
+    List<ScoreDocsAndScoreVariance> nearestNeighbors = findNearestNeighbors(sampledDocs, function);
+    float[] bestPair =
+        candidateGridSearch(
+            nearestNeighbors, sampledDocs, lowerCandidates, upperCandidates, function, bits);
+    return new ScalarQuantizer(bestPair[0], bestPair[1], bits);
+  }
+
+  private static void extractQuantiles(
+      float[] confidenceIntervals,
+      float[] quantileGatheringScratch,
+      double[] upperSum,
+      double[] lowerSum) {
+    assert confidenceIntervals.length == upperSum.length
+        && confidenceIntervals.length == lowerSum.length;
+    for (int i = 0; i < confidenceIntervals.length; i++) {
+      float[] upperAndLower =
+          getUpperAndLowerQuantile(quantileGatheringScratch, confidenceIntervals[i]);
+      upperSum[i] += upperAndLower[1];
+      lowerSum[i] += upperAndLower[0];
+    }
+  }
+
+  private static void gatherSample(
+      FloatVectorValues floatVectorValues,
+      float[] quantileGatheringScratch,
+      List<float[]> sampledDocs,
+      int i)
+      throws IOException {
+    float[] vectorValue = floatVectorValues.vectorValue();
+    float[] copy = new float[vectorValue.length];
+    System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length);
+    sampledDocs.add(copy);
+    System.arraycopy(
+        vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
+  }
+
+  private static float[] candidateGridSearch(
+      List<ScoreDocsAndScoreVariance> nearestNeighbors,
+      List<float[]> vectors,
+      float[] lowerCandidates,
+      float[] upperCandidates,
+      VectorSimilarityFunction function,
+      byte bits) {
+    double maxCorr = Double.NEGATIVE_INFINITY;
+    float bestLower = 0f;
+    float bestUpper = 0f;
+    ScoreErrorCorrelator scoreErrorCorrelator =
+        new ScoreErrorCorrelator(function, nearestNeighbors, vectors, bits);
+    // first do a coarse grained search to find the initial best candidate pair
+    int bestQuandrantLower = 0;
+    int bestQuandrantUpper = 0;
+    for (int i = 0; i < lowerCandidates.length; i += 4) {
+      float lower = lowerCandidates[i];
+      for (int j = 0; j < upperCandidates.length; j += 4) {
+        float upper = upperCandidates[j];
+        if (upper <= lower) {
+          continue;
+        }
+        double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
+        if (mean > maxCorr) {
+          maxCorr = mean;
+          bestLower = lower;
+          bestUpper = upper;
+          bestQuandrantLower = i;
+          bestQuandrantUpper = j;
+        }
+      }
+    }
+    // Now search within the best quadrant
+    for (int i = bestQuandrantLower + 1; i < bestQuandrantLower + 4; i++) {
+      for (int j = bestQuandrantUpper + 1; j < bestQuandrantUpper + 4; j++) {
+        float lower = lowerCandidates[i];
+        float upper = upperCandidates[j];
+        if (upper <= lower) {
+          continue;
+        }
+        double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
+        if (mean > maxCorr) {
+          maxCorr = mean;
+          bestLower = lower;
+          bestUpper = upper;
+        }
+      }
+    }
+    return new float[] {bestLower, bestUpper};
+  }
+
+  /**
+   * @param vectors The vectors to find the nearest neighbors for each other
+   * @param similarityFunction The similarity function to use
+   * @return The top 10 nearest neighbors for each vector from the vectors list
+   */
+  private static List<ScoreDocsAndScoreVariance> findNearestNeighbors(
+      List<float[]> vectors, VectorSimilarityFunction similarityFunction) {
+    List<HitQueue> queues = new ArrayList<>(vectors.size());
+    queues.add(new HitQueue(10, false));
+    for (int i = 0; i < vectors.size(); i++) {
+      float[] vector = vectors.get(i);
+      for (int j = i + 1; j < vectors.size(); j++) {
+        float[] otherVector = vectors.get(j);
+        float score = similarityFunction.compare(vector, otherVector);
+        // initialize the rest of the queues
+        if (queues.size() <= j) {
+          queues.add(new HitQueue(10, false));
+        }
+        queues.get(i).insertWithOverflow(new ScoreDoc(j, score));
+        queues.get(j).insertWithOverflow(new ScoreDoc(i, score));
+      }
+    }
+    // Extract the top 10 from each queue
+    List<ScoreDocsAndScoreVariance> result = new ArrayList<>(vectors.size());
+    OnlineMeanAndVar meanAndVar = new OnlineMeanAndVar();
+    for (int i = 0; i < vectors.size(); i++) {
+      HitQueue queue = queues.get(i);
+      ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
+      for (int j = queue.size() - 1; j >= 0; j--) {
+        scoreDocs[j] = queue.pop();
+        assert scoreDocs[j] != null;
+        meanAndVar.add(scoreDocs[j].score);
+      }
+      result.add(new ScoreDocsAndScoreVariance(scoreDocs, meanAndVar.var()));
+      meanAndVar.reset();
+    }
+    return result;
+  }
 
   /**
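fromVectorsAutoInterval thus replaces the user-supplied confidence interval with a small search: it gathers two candidate quantile ranges from the sample, lays 16 lower and 16 upper candidates between them, and picks the pair whose quantized scores best track the float scores, scanning every fourth candidate first and then refining inside the winning quadrant. A hedged calling sketch (floatVectorValues and totalVectorCount are assumed to be in scope):

    // Sketch: letting the quantizer pick its own interval for 4-bit quantization.
    ScalarQuantizer quantizer =
        ScalarQuantizer.fromVectorsAutoInterval(
            floatVectorValues, VectorSimilarityFunction.DOT_PRODUCT, totalVectorCount, (byte) 4);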
@@ -319,9 +532,8 @@ public class ScalarQuantizer {
    * @return lower and upper quantile values
    */
   static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) {
-    assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
     int selectorIndex = (int) (arr.length * (1f - confidenceInterval) / 2f + 0.5f);
-    if (selectorIndex > 0) {
+    if (selectorIndex > 0 && arr.length > 2) {
       Selector selector = new FloatSelector(arr);
       selector.select(0, arr.length, arr.length - selectorIndex);
       selector.select(0, arr.length - selectorIndex, selectorIndex);
@@ -361,4 +573,95 @@ public class ScalarQuantizer {
       arr[j] = tmp;
     }
   }
 
+  private static class ScoreDocsAndScoreVariance {
+    private final ScoreDoc[] scoreDocs;
+    private final float scoreVariance;
+
+    public ScoreDocsAndScoreVariance(ScoreDoc[] scoreDocs, float scoreVariance) {
+      this.scoreDocs = scoreDocs;
+      this.scoreVariance = scoreVariance;
+    }
+
+    public ScoreDoc[] getScoreDocs() {
+      return scoreDocs;
+    }
+  }
+
+  private static class OnlineMeanAndVar {
+    private double mean = 0.0;
+    private double var = 0.0;
+    private int n = 0;
+
+    void reset() {
+      mean = 0.0;
+      var = 0.0;
+      n = 0;
+    }
+
+    void add(double x) {
+      n++;
+      double delta = x - mean;
+      mean += delta / n;
+      var += delta * (x - mean);
+    }
+
+    float var() {
+      return (float) (var / (n - 1));
+    }
+  }
+
+  /**
+   * This class is used to correlate the scores of the nearest neighbors with the errors in the
+   * scores. This is used to find the best quantile pair for the scalar quantizer.
+   */
+  private static class ScoreErrorCorrelator {
+    private final OnlineMeanAndVar corr = new OnlineMeanAndVar();
+    private final OnlineMeanAndVar errors = new OnlineMeanAndVar();
+    private final VectorSimilarityFunction function;
+    private final List<ScoreDocsAndScoreVariance> nearestNeighbors;
+    private final List<float[]> vectors;
+    private final byte[] query;
+    private final byte[] vector;
+    private final byte bits;
+
+    public ScoreErrorCorrelator(
+        VectorSimilarityFunction function,
+        List<ScoreDocsAndScoreVariance> nearestNeighbors,
+        List<float[]> vectors,
+        byte bits) {
+      this.function = function;
+      this.nearestNeighbors = nearestNeighbors;
+      this.vectors = vectors;
+      this.query = new byte[vectors.get(0).length];
+      this.vector = new byte[vectors.get(0).length];
+      this.bits = bits;
+    }
+
+    double scoreErrorCorrelation(float lowerQuantile, float upperQuantile) {
+      corr.reset();
+      ScalarQuantizer quantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
+      ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity =
+          ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
+              function, quantizer.getConstantMultiplier(), quantizer.bits);
+      for (int i = 0; i < nearestNeighbors.size(); i++) {
+        float queryCorrection = quantizer.quantize(vectors.get(i), query, function);
+        ScoreDocsAndScoreVariance scoreDocsAndScoreVariance = nearestNeighbors.get(i);
+        ScoreDoc[] scoreDocs = scoreDocsAndScoreVariance.getScoreDocs();
+        float scoreVariance = scoreDocsAndScoreVariance.scoreVariance;
+        // calculate the score for the vector against its nearest neighbors but with quantized
+        // scores now
+        errors.reset();
+        for (ScoreDoc scoreDoc : scoreDocs) {
+          float vectorCorrection = quantizer.quantize(vectors.get(scoreDoc.doc), vector, function);
+          float qScore =
+              scalarQuantizedVectorSimilarity.score(
+                  query, queryCorrection, vector, vectorCorrection);
+          errors.add(qScore - scoreDoc.score);
+        }
+        corr.add(1 - errors.var() / scoreVariance);
+      }
+      return corr.mean;
+    }
+  }
 }

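OnlineMeanAndVar is Welford's single-pass algorithm: add() folds each observation into a running mean and a running sum of squared deviations, and var() divides by n - 1 for the unbiased sample variance. A quick standalone sketch of the same recurrence (illustrative only, not the Lucene class):

    // Sketch: Welford's single-pass update; result equals the two-pass sample variance.
    static double welfordVariance(double[] xs) {
      double mean = 0, m2 = 0;
      int n = 0;
      for (double x : xs) {
        n++;
        double delta = x - mean;
        mean += delta / n;         // running mean
        m2 += delta * (x - mean);  // running sum of squared deviations
      }
      return m2 / (n - 1);         // unbiased sample variance
    }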
@@ -389,6 +389,53 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
     return acc.reduceLanes(ADD);
   }
 
+  @Override
+  public int int4DotProduct(byte[] a, byte[] b) {
+    int i = 0;
+    int res = 0;
+    if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) {
+      return dotProduct(a, b);
+    } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) {
+      i += ByteVector.SPECIES_128.loopBound(a.length);
+      res += int4DotProductBody128(a, b, i);
+    }
+    // scalar tail
+    for (; i < a.length; i++) {
+      res += b[i] * a[i];
+    }
+    return res;
+  }
+
+  private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
+    int sum = 0;
+    // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator
+    for (int i = 0; i < limit; i += 1024) {
+      ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
+      ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
+      int innerLimit = Math.min(limit - i, 1024);
+      for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
+        ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
+        ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
+        ByteVector prod8 = va8.mul(vb8);
+        ShortVector prod16 =
+            prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
+        acc0 = acc0.add(prod16.and((short) 0xFF));
+
+        va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
+        vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
+        prod8 = va8.mul(vb8);
+        prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
+        acc1 = acc1.add(prod16.and((short) 0xFF));
+      }
+      IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
+      IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
+      IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
+      IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
+      sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+    }
+    return sum;
+  }
+
   @Override
   public float cosine(byte[] a, byte[] b) {
     int i = 0;

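The 1024-element chunking in int4DotProductBody128 is what keeps the 16-bit accumulators safe, and the margin is easy to check: each short lane absorbs one masked product per 16-byte stride (at most 255 after the 0xFF mask, and at most 15 * 15 = 225 for genuine int4 inputs), so a lane sees at most 1024 / 16 = 64 additions per chunk. That bounds each lane by 64 * 255 = 16320, comfortably below Short.MAX_VALUE = 32767, before the chunk is widened to ints and reduced. This is also why only the sub-256-bit path needs the trick: on wider vectors the code falls back to the plain int-accumulating dotProduct.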
@@ -42,17 +42,37 @@ import org.apache.lucene.util.SameThreadExecutorService;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
 import org.apache.lucene.util.quantization.ScalarQuantizer;
+import org.junit.Before;
 
 public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
 
+  KnnVectorsFormat format;
+  Float confidenceInterval;
+  int bits;
+
+  @Before
+  @Override
+  public void setUp() throws Exception {
+    bits = random().nextBoolean() ? 4 : 7;
+    confidenceInterval = random().nextBoolean() ? 0.99f : null;
+    format =
+        new Lucene99HnswScalarQuantizedVectorsFormat(
+            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
+            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
+            1,
+            bits,
+            random().nextBoolean(),
+            confidenceInterval,
+            null);
+    super.setUp();
+  }
+
   @Override
   protected Codec getCodec() {
     return new Lucene99Codec() {
       @Override
       public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
-        return new Lucene99HnswScalarQuantizedVectorsFormat(
-            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
-            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+        return format;
       }
     };
   }
@@ -63,17 +83,25 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
     VectorSimilarityFunction similarityFunction = randomSimilarity();
     boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
     int dim = random().nextInt(64) + 1;
+    if (dim % 2 == 1) {
+      dim++;
+    }
     List<float[]> vectors = new ArrayList<>(numVectors);
     for (int i = 0; i < numVectors; i++) {
       vectors.add(randomVector(dim));
     }
-    float confidenceInterval =
-        Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim);
     ScalarQuantizer scalarQuantizer =
-        ScalarQuantizer.fromVectors(
-            new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
-            confidenceInterval,
-            numVectors);
+        confidenceInterval == null
+            ? ScalarQuantizer.fromVectorsAutoInterval(
+                new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
+                similarityFunction,
+                numVectors,
+                (byte) bits)
+            : ScalarQuantizer.fromVectors(
+                new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
+                confidenceInterval,
+                numVectors,
+                (byte) bits);
     float[] expectedCorrections = new float[numVectors];
     byte[][] expectedVectors = new byte[numVectors][];
     for (int i = 0; i < numVectors; i++) {

@@ -149,11 +177,12 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
         new FilterCodec("foo", Codec.getDefault()) {
           @Override
           public KnnVectorsFormat knnVectorsFormat() {
-            return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null);
+            return new Lucene99HnswScalarQuantizedVectorsFormat(
+                10, 20, 1, (byte) 4, false, 0.9f, null);
           }
         };
     String expectedString =
-        "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
+        "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, rawVectorFormat=Lucene99FlatVectorsFormat()))";
     assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
   }

@@ -174,15 +203,28 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
         () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 3201));
     expectThrows(
         IllegalArgumentException.class,
-        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 1.1f, null));
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 1.1f, null));
     expectThrows(
         IllegalArgumentException.class,
-        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 0.8f, null));
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, -1, false, null, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 5, false, null, null));
+
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 9, false, null, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 0.8f, null));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 100, 7, false, null, null));
     expectThrows(
         IllegalArgumentException.class,
         () ->
             new Lucene99HnswScalarQuantizedVectorsFormat(
-                20, 100, 1, null, new SameThreadExecutorService()));
+                20, 100, 1, 7, false, null, new SameThreadExecutorService()));
   }
 
   // Ensures that all expected vector similarity functions are translatable

@@ -39,14 +39,17 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
     FloatVectorValues floatVectorValues = fromFloats(floats);
     ScalarQuantizer scalarQuantizer =
-        ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
+        ScalarQuantizer.fromVectors(
+            floatVectorValues, confidenceInterval, floats.length, (byte) 7);
     byte[][] quantized = new byte[floats.length][];
     float[] offsets =
         quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
     float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
     ScalarQuantizedVectorSimilarity quantizedSimilarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
-            VectorSimilarityFunction.EUCLIDEAN, scalarQuantizer.getConstantMultiplier());
+            VectorSimilarityFunction.EUCLIDEAN,
+            scalarQuantizer.getConstantMultiplier(),
+            scalarQuantizer.getBits());
     assertQuantizedScores(
         floats,
         quantized,

@@ -69,7 +72,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
     FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null);
     ScalarQuantizer scalarQuantizer =
-        ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
+        ScalarQuantizer.fromVectors(
+            floatVectorValues, confidenceInterval, floats.length, (byte) 7);
     byte[][] quantized = new byte[floats.length][];
     float[] offsets =
         quantizeVectorsNormalized(

@@ -78,7 +82,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     VectorUtil.l2normalize(query);
     ScalarQuantizedVectorSimilarity quantizedSimilarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
-            VectorSimilarityFunction.COSINE, scalarQuantizer.getConstantMultiplier());
+            VectorSimilarityFunction.COSINE,
+            scalarQuantizer.getConstantMultiplier(),
+            scalarQuantizer.getBits());
     assertQuantizedScores(
         floats,
         quantized,

@@ -103,7 +109,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
     FloatVectorValues floatVectorValues = fromFloats(floats);
     ScalarQuantizer scalarQuantizer =
-        ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
+        ScalarQuantizer.fromVectors(
+            floatVectorValues, confidenceInterval, floats.length, (byte) 7);
     byte[][] quantized = new byte[floats.length][];
     float[] offsets =
         quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);

@@ -111,7 +118,9 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     VectorUtil.l2normalize(query);
     ScalarQuantizedVectorSimilarity quantizedSimilarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
-            VectorSimilarityFunction.DOT_PRODUCT, scalarQuantizer.getConstantMultiplier());
+            VectorSimilarityFunction.DOT_PRODUCT,
+            scalarQuantizer.getConstantMultiplier(),
+            scalarQuantizer.getBits());
     assertQuantizedScores(
         floats,
         quantized,

@@ -133,7 +142,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
     FloatVectorValues floatVectorValues = fromFloats(floats);
     ScalarQuantizer scalarQuantizer =
-        ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length);
+        ScalarQuantizer.fromVectors(
+            floatVectorValues, confidenceInterval, floats.length, (byte) 7);
     byte[][] quantized = new byte[floats.length][];
     float[] offsets =
         quantizeVectors(

@@ -142,7 +152,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
     ScalarQuantizedVectorSimilarity quantizedSimilarity =
         ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
             VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
-            scalarQuantizer.getConstantMultiplier());
+            scalarQuantizer.getConstantMultiplier(),
+            scalarQuantizer.getBits());
     assertQuantizedScores(
         floats,
         quantized,

@@ -34,7 +34,8 @@ public class TestScalarQuantizer extends LuceneTestCase {
 
     float[][] floats = randomFloats(numVecs, dims);
     FloatVectorValues floatVectorValues = fromFloats(floats);
-    ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs);
+    ScalarQuantizer scalarQuantizer =
+        ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs, (byte) 7);
     float[] dequantized = new float[dims];
     byte[] quantized = new byte[dims];
     byte[] requantized = new byte[dims];

@@ -87,6 +88,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
+          (byte) 7,
           Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {

@@ -96,6 +98,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
+          (byte) 7,
           Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {

@@ -105,6 +108,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
+          (byte) 7,
           Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1));
     }
     {

@@ -114,10 +118,36 @@ public class TestScalarQuantizer extends LuceneTestCase {
           floatVectorValues,
           0.99f,
           floatVectorValues.numLiveVectors,
+          (byte) 7,
           Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1));
     }
   }
 
+  public void testFromVectorsAutoInterval() throws IOException {
+    int dims = 128;
+    int numVecs = 100;
+    VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
+
+    float[][] floats = randomFloats(numVecs, dims);
+    FloatVectorValues floatVectorValues = fromFloats(floats);
+    ScalarQuantizer scalarQuantizer =
+        ScalarQuantizer.fromVectorsAutoInterval(
+            floatVectorValues, similarityFunction, numVecs, (byte) 4);
+    assertNotNull(scalarQuantizer);
+    float[] dequantized = new float[dims];
+    byte[] quantized = new byte[dims];
+    byte[] requantized = new byte[dims];
+    for (int i = 0; i < numVecs; i++) {
+      scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
+      scalarQuantizer.deQuantize(quantized, dequantized);
+      scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
+      for (int j = 0; j < dims; j++) {
+        assertEquals(dequantized[j], floats[i][j], 0.2);
+        assertEquals(quantized[j], requantized[j]);
+      }
+    }
+  }
+
   static void shuffleArray(float[] ar) {
     for (int i = ar.length - 1; i > 0; i--) {
       int index = random().nextInt(i + 1);

@@ -127,12 +127,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
       w.addDocument(doc);
 
       Document doc2 = new Document();
-      doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
+      doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
       IllegalArgumentException expected =
           expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
       String errMsg =
           "Inconsistency of field data structures across documents for field [f] of doc [1]."
-              + " vector dimension: expected '4', but it has '3'.";
+              + " vector dimension: expected '4', but it has '6'.";
       assertEquals(errMsg, expected.getMessage());
     }

@@ -145,12 +145,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
       w.commit();
 
       Document doc2 = new Document();
-      doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
+      doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
       IllegalArgumentException expected =
          expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
       String errMsg =
           "cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
-              + "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
+              + "to inconsistent vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
       assertEquals(errMsg, expected.getMessage());
     }
   }

@@ -202,12 +202,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
 
       try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) {
         Document doc2 = new Document();
-        doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
+        doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT));
         IllegalArgumentException expected =
             expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
         assertEquals(
             "cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
-                + "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
+                + "to inconsistent vector dimension=2, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
             expected.getMessage());
       }
     }

@@ -284,7 +284,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
 
   public void testAddIndexesDirectory01() throws Exception {
     String fieldName = "field";
-    float[] vector = new float[1];
+    float[] vector = new float[2];
     Document doc = new Document();
     doc.add(new KnnFloatVectorField(fieldName, vector, VectorSimilarityFunction.DOT_PRODUCT));
     try (Directory dir = newDirectory();

@@ -294,6 +294,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     }
     try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
       vector[0] = 1;
+      vector[1] = 1;
       w2.addDocument(doc);
       w2.addIndexes(dir);
       w2.forceMerge(1);

@@ -322,13 +323,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     }
     try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
       Document doc = new Document();
-      doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
+      doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
       w2.addDocument(doc);
       IllegalArgumentException expected =
           expectThrows(
               IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
       assertEquals(
-          "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+          "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
              + "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
           expected.getMessage());
     }

@@ -367,7 +368,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     }
     try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
       Document doc = new Document();
-      doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
+      doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
       w2.addDocument(doc);
       try (DirectoryReader r = DirectoryReader.open(dir)) {
         IllegalArgumentException expected =

@@ -375,7 +376,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
                 IllegalArgumentException.class,
                 () -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
         assertEquals(
-            "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+            "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
                + "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
             expected.getMessage());
       }

@@ -419,13 +420,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     }
     try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
      Document doc = new Document();
-      doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
+      doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT));
       w2.addDocument(doc);
       try (DirectoryReader r = DirectoryReader.open(dir)) {
         IllegalArgumentException expected =
             expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
         assertEquals(
-            "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+            "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
                + "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
             expected.getMessage());
       }

@@ -486,7 +487,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
             .contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
 
       Document doc2 = new Document();
-      doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
+      doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT));
       w.addDocument(doc2);
 
       Document doc3 = new Document();

@@ -531,7 +532,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
       assertEquals("cannot index an empty vector", e.getMessage());
 
       Document doc2 = new Document();
-      doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
+      doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.EUCLIDEAN));
       w.addDocument(doc2);
     }
   }

@@ -592,7 +593,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         doc.add(new StringField("id", "0", Field.Store.NO));
         doc.add(
             new KnnFloatVectorField(
-                "v", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
+                "v", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
         w.addDocument(doc);
         w.addDocument(new Document());
         w.commit();

@@ -613,7 +614,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         // assert that knn search doesn't fail on a field with all deleted docs
         TopDocs results =
             leafReader.searchNearestVectors(
-                "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+                "v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
         assertEquals(0, results.scoreDocs.length);
       }
     }

@@ -626,14 +627,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         doc.add(new StringField("id", "0", Field.Store.NO));
         doc.add(
             new KnnFloatVectorField(
-                "v0", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
+                "v0", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
         w.addDocument(doc);
         w.commit();
 
         doc = new Document();
         doc.add(
             new KnnFloatVectorField(
-                "v1", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
+                "v1", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT));
         w.addDocument(doc);
         w.forceMerge(1);
       }

@@ -649,6 +650,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
     for (int i = 0; i < numFields; i++) {
       fieldDims[i] = random().nextInt(20) + 1;
+      if (fieldDims[i] % 2 != 0) {
+        fieldDims[i]++;
+      }
       fieldSimilarityFunctions[i] = randomSimilarity();
       fieldVectorEncodings[i] = randomVectorEncoding();
     }

@@ -731,7 +735,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     // We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
     // calls to IndexWriter.addDocument.
     String fieldName = "field";
-    float[] v = {0};
+    float[] v = {0, 0};
     try (Directory dir = newDirectory();
         IndexWriter iw =
             new IndexWriter(

@@ -829,25 +833,25 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         IndexWriter iw =
             new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
       Document doc = new Document();
-      float[] v = new float[] {1};
+      float[] v = new float[] {1, 2};
       doc.add(new KnnFloatVectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN));
       doc.add(
           new KnnFloatVectorField(
-              "field2", new float[] {1, 2, 3}, VectorSimilarityFunction.EUCLIDEAN));
+              "field2", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN));
       iw.addDocument(doc);
       v[0] = 2;
       iw.addDocument(doc);
       doc = new Document();
       doc.add(
           new KnnFloatVectorField(
-              "field3", new float[] {1, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
+              "field3", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.DOT_PRODUCT));
       iw.addDocument(doc);
       iw.forceMerge(1);
       try (IndexReader reader = DirectoryReader.open(iw)) {
         LeafReader leaf = reader.leaves().get(0).reader();
 
         FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1");
-        assertEquals(1, vectorValues.dimension());
+        assertEquals(2, vectorValues.dimension());
         assertEquals(2, vectorValues.size());
         vectorValues.nextDoc();
         assertEquals(1f, vectorValues.vectorValue()[0], 0);

@@ -856,7 +860,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
 
         FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2");
-        assertEquals(3, vectorValues2.dimension());
+        assertEquals(4, vectorValues2.dimension());
         assertEquals(2, vectorValues2.size());
         vectorValues2.nextDoc();
         assertEquals(2f, vectorValues2.vectorValue()[1], 0);

@@ -865,7 +869,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());
 
         FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3");
-        assertEquals(3, vectorValues3.dimension());
+        assertEquals(4, vectorValues3.dimension());
         assertEquals(1, vectorValues3.size());
         vectorValues3.nextDoc();
         assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);

@@ -889,6 +893,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         IndexWriter iw = new IndexWriter(dir, iwc)) {
       int numDoc = atLeast(100);
       int dimension = atLeast(10);
+      if (dimension % 2 != 0) {
+        dimension++;
+      }
       float[] scratch = new float[dimension];
       int numValues = 0;
       float[][] values = new float[numDoc][];

@@ -965,6 +972,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         IndexWriter iw = new IndexWriter(dir, iwc)) {
       int numDoc = atLeast(100);
       int dimension = atLeast(10);
+      if (dimension % 2 != 0) {
+        dimension++;
+      }
       byte[] scratch = new byte[dimension];
       int numValues = 0;
       BytesRef[] values = new BytesRef[numDoc];

@@ -1101,6 +1111,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
         IndexWriter iw = new IndexWriter(dir, iwc)) {
       int numDoc = atLeast(100);
       int dimension = atLeast(10);
+      if (dimension % 2 != 0) {
+        dimension++;
+      }
       float[][] id2value = new float[numDoc][];
       for (int i = 0; i < numDoc; i++) {
         int id = random().nextInt(numDoc);

@@ -1252,7 +1265,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
     return v;
   }
 
-  private byte[] randomVector8(int dim) {
+  protected byte[] randomVector8(int dim) {
     assert dim > 0;
     float[] v = randomNormalizedVector(dim);
     byte[] b = new byte[dim];

@@ -1268,12 +1281,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
       Document doc = new Document();
       doc.add(
           new KnnFloatVectorField(
-              "v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
+              "v1", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN));
       w.addDocument(doc);
 
       doc.add(
          new KnnFloatVectorField(
-              "v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
+              "v2", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN));
       w.addDocument(doc);
     }
 
@@ -1360,7 +1373,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
 
   public void testVectorValuesReportCorrectDocs() throws Exception {
     final int numDocs = atLeast(1000);
-    final int dim = random().nextInt(20) + 1;
+    int dim = random().nextInt(20) + 1;
+    if (dim % 2 != 0) {
+      dim++;
+    }
     double fieldValuesCheckSum = 0;
     int fieldDocCount = 0;
     long fieldSumDocIDs = 0;