Refactoring HNSW to use a new internal FlatVectorFormat (#12729)

Currently the HNSW codec does too many things, it not only indexes vectors, but stores them and determines how to store them given the vector type.

This PR extracts out the vector storage into a new format `Lucene99FlatVectorsFormat` and adds new base class called `FlatVectorsFormat`. This allows for some additional helper functions that allow an indexing codec (like HNSW) take advantage of the flat formats.

Additionally, this PR refactors the new `Lucene99ScalarQuantizedVectorsFormat` to be a `FlatVectorsFormat`.

Now, `Lucene99HnswVectorsFormat` is constructed with a `Lucene99FlatVectorsFormat` and a new `Lucene99HnswScalarQuantizedVectorsFormat` that uses `Lucene99ScalarQuantizedVectorsFormat`
This commit is contained in:
Benjamin Trent 2023-11-10 14:05:19 -05:00 committed by GitHub
parent c28d174cd7
commit a47ba3369f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 2229 additions and 1015 deletions

View File

@ -184,6 +184,10 @@ New Features
* GITHUB#12660: HNSW graph now can be merged with multiple thread. Configurable in Lucene99HnswVectorsFormat. * GITHUB#12660: HNSW graph now can be merged with multiple thread. Configurable in Lucene99HnswVectorsFormat.
(Patrick Zhai) (Patrick Zhai)
* GITHUB#12729: Add new Lucene99FlatVectorsFormat for writing vectors in a flat format and refactor
Lucene99ScalarQuantizedVectorsFormat & Lucene99HnswVectorsFormat to reuse the flat formats.
Additionally, this allows flat formats to be pluggable independent of HNSW. (Ben Trent)
Improvements Improvements
--------------------- ---------------------
* GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions * GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions

View File

@ -80,8 +80,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
} }
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
private int doc = -1; private int doc = -1;
@ -120,7 +118,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs; return acceptDocs;
} }
} }
@ -184,7 +182,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) { if (acceptDocs == null) {
return null; return null;
} }
@ -256,7 +254,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return null; return null;
} }
} }

View File

@ -89,8 +89,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
} }
} }
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1; private int doc = -1;
@ -129,7 +127,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs; return acceptDocs;
} }
} }
@ -196,7 +194,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) { if (acceptDocs == null) {
return null; return null;
} }
@ -268,7 +266,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return null; return null;
} }
} }

View File

@ -86,8 +86,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
} }
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
private int doc = -1; private int doc = -1;
@ -126,7 +124,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs; return acceptDocs;
} }
} }
@ -193,7 +191,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) { if (acceptDocs == null) {
return null; return null;
} }
@ -265,7 +263,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return null; return null;
} }
} }

View File

@ -16,7 +16,6 @@
*/ */
import org.apache.lucene.codecs.lucene99.Lucene99Codec; import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
/** Lucene Core. */ /** Lucene Core. */
@SuppressWarnings("module") // the test framework is compiled after the core... @SuppressWarnings("module") // the test framework is compiled after the core...
@ -70,7 +69,8 @@ module org.apache.lucene.core {
provides org.apache.lucene.codecs.DocValuesFormat with provides org.apache.lucene.codecs.DocValuesFormat with
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
provides org.apache.lucene.codecs.KnnVectorsFormat with provides org.apache.lucene.codecs.KnnVectorsFormat with
Lucene99HnswVectorsFormat; org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat; org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with provides org.apache.lucene.index.SortFieldProvider with

View File

@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
/**
* Vectors' writer for a field
*
* @param <T> an array type; the type of vectors to be written
* @lucene.experimental
*/
public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
/**
* The delegate to write to, can be null When non-null, all vectors seen should be written to the
* delegate along with being written to the flat vectors.
*/
protected final KnnFieldVectorsWriter<T> indexingDelegate;
/**
* Sole constructor that expects some indexingDelegate. All vectors seen should be written to the
* delegate along with being written to the flat vectors.
*
* @param indexingDelegate the delegate to write to, can be null
*/
protected FlatFieldVectorsWriter(KnnFieldVectorsWriter<T> indexingDelegate) {
this.indexingDelegate = indexingDelegate;
}
}

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
/**
* Encodes/decodes per-document vectors
*
* @lucene.experimental
*/
public abstract class FlatVectorsFormat {
/** Sole constructor */
protected FlatVectorsFormat() {}
/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;
}

View File

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
* Reads vectors from an index. When searching this reader, it iterates every vector in the index
* and scores them
*
* <p>This class is useful when:
*
* <ul>
* <li>the number of vectors is small
* <li>when used along side some additional indexing structure that can be used to better search
* the vectors (like HNSW).
* </ul>
*
* @lucene.experimental
*/
public abstract class FlatVectorsReader implements Closeable, Accountable {
/** Sole constructor */
protected FlatVectorsReader() {}
/**
* Returns a {@link RandomVectorScorer} for the given field and target vector.
*
* @param field the field to search
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given field and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] target)
throws IOException;
/**
* Returns a {@link RandomVectorScorer} for the given field and target vector.
*
* @param field the field to search
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given field and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
throws IOException;
/**
* Checks consistency of this reader.
*
* <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
* against large data files.
*
* @lucene.internal
*/
public abstract void checkIntegrity() throws IOException;
/**
* Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;
/**
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
}

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
/**
* Vectors' writer for a field that allows additional indexing logic to be implemented by the caller
*
* @lucene.experimental
*/
public abstract class FlatVectorsWriter implements Accountable, Closeable {
/** Sole constructor */
protected FlatVectorsWriter() {}
/**
* Add a new field for indexing, allowing the user to provide a writer that the flat vectors
* writer can delegate to if additional indexing logic is required.
*
* @param fieldInfo fieldInfo of the field to add
* @param indexWriter the writer to delegate to, can be null
* @return a writer for the field
* @throws IOException if an I/O error occurs when adding the field
*/
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
* any additional merging logic can be implemented by the user of this class.
*
* @param fieldInfo fieldInfo of the field to merge
* @param mergeState mergeState of the segments to merge
* @return a scorer over the newly merged flat vectors, which should be closed as it holds a
* temporary file handle to read over the newly merged vectors
* @throws IOException if an I/O error occurs when merging
*/
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException;
/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
}
/** Called once at the end before close */
public abstract void finish() throws IOException;
/** Flush all buffered data on disk * */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
}

View File

@ -94,8 +94,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
} }
} }
public abstract Bits getAcceptOrds(Bits acceptDocs);
/** /**
* Dense vector values that are stored off-heap. This is the most common case when every doc has a * Dense vector values that are stored off-heap. This is the most common case when every doc has a
* vector. * vector.

View File

@ -88,8 +88,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
} }
} }
public abstract Bits getAcceptOrds(Bits acceptDocs);
/** /**
* Dense vector values that are stored off-heap. This is the most common case when every doc has a * Dense vector values that are stored off-heap. This is the most common case when every doc has a
* vector. * vector.

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;
/**
* Lucene 9.9 flat vector format, which encodes numeric vector values
*
* <h2>.vec (vector data) file</h2>
*
* <p>For each field:
*
* <ul>
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
* note that only in sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case
* </ul>
*
* <h2>.vemf (vector metadata) file</h2>
*
* <p>For each field:
*
* <ul>
* <li><b>[int32]</b> field number
* <li><b>[int32]</b> vector similarity function ordinal
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
* <li><b>[vlong]</b> length of this field's vectors, in bytes
* <li><b>[vint]</b> dimension of this field's vectors
* <li><b>[int]</b> the number of documents having values for this field
* <li><b>[int8]</b> if equals to -1, dense all documents have values for a field. If equals to
* 0, sparse some documents missing values.
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case
* </ul>
*
* @lucene.experimental
*/
public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData";
static final String META_EXTENSION = "vemf";
static final String VECTOR_DATA_EXTENSION = "vec";
public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = VERSION_START;
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
/** Constructs a format */
public Lucene99FlatVectorsFormat() {
super();
}
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99FlatVectorsReader(state);
}
@Override
public String toString() {
return "Lucene99FlatVectorsFormat()";
}
}

View File

@ -0,0 +1,333 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
* Reads vectors from the index segments.
*
* @lucene.experimental
*/
public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
Lucene99FlatVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
boolean success = false;
try {
vectorData =
openDataInput(
state,
versionMeta,
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
private int readMetadata(SegmentReadState state) throws IOException {
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
int versionMeta = -1;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
Throwable priorE = null;
try {
versionMeta =
CodecUtil.checkIndexHeader(
meta,
Lucene99FlatVectorsFormat.META_CODEC_NAME,
Lucene99FlatVectorsFormat.VERSION_START,
Lucene99FlatVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta, state.fieldInfos);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(meta, priorE);
}
}
return versionMeta;
}
private static IndexInput openDataInput(
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
throws IOException {
String fileName =
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
IndexInput in = state.directory.openInput(fileName, state.context);
boolean success = false;
try {
int versionVectorData =
CodecUtil.checkIndexHeader(
in,
codecName,
Lucene99FlatVectorsFormat.VERSION_START,
Lucene99FlatVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException(
"Format versions mismatch: meta="
+ versionMeta
+ ", "
+ codecName
+ "="
+ versionVectorData,
in);
}
CodecUtil.retrieveChecksum(in);
success = true;
return in;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(in);
}
}
}
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
FieldInfo info = infos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
FieldEntry fieldEntry = readField(meta);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
}
}
private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
int dimension = info.getVectorDimension();
if (dimension != fieldEntry.dimension) {
throw new IllegalStateException(
"Inconsistent vector dimension for field=\""
+ info.name
+ "\"; "
+ dimension
+ " != "
+ fieldEntry.dimension);
}
int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
"Vector data length "
+ fieldEntry.vectorDataLength
+ " not matching size="
+ fieldEntry.size
+ " * dim="
+ dimension
+ " * byteSize="
+ byteSize
+ " = "
+ numBytes);
}
}
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorSimilarityFunction.values()[similarityFunctionId];
}
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction);
}
@Override
public long ramBytesUsed() {
return Lucene99FlatVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
}
@Override
public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
}
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return RandomVectorScorer.createFloats(
OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData),
fieldEntry.similarityFunction,
target);
}
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return null;
}
return RandomVectorScorer.createBytes(
OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData),
fieldEntry.similarityFunction,
target);
}
@Override
public void close() throws IOException {
IOUtils.close(vectorData);
}
private static class FieldEntry implements Accountable {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final int dimension;
final long vectorDataOffset;
final long vectorDataLength;
final int size;
final OrdToDocDISIReaderConfiguration ordToDoc;
FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
dimension = input.readVInt();
size = input.readInt();
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
}
}
}

View File

@ -0,0 +1,508 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* Writes vector values to index segments.
*
* @lucene.experimental
*/
public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private static final long SHALLLOW_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class);
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException {
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
String vectorDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION);
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
CodecUtil.writeIndexHeader(
meta,
Lucene99FlatVectorsFormat.META_CODEC_NAME,
Lucene99FlatVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
vectorData,
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
Lucene99FlatVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
fields.add(newField);
return newField;
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
if (sortMap == null) {
writeField(field, maxDoc);
} else {
writeSortingField(field, maxDoc, sortMap);
}
}
}
@Override
public void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (vectorData != null) {
CodecUtil.writeFooter(vectorData);
}
}
@Override
public long ramBytesUsed() {
long total = SHALLLOW_RAM_BYTES_USED;
for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField);
}
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
byte[] vector = (byte[]) v;
vectorData.writeBytes(vector, vector.length);
}
}
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
// write vector values
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField);
}
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
// Since we know we will not be searching for additional indexing, we can just write the
// the vectors directly to the new segment.
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
// No need to use temporary file as we don't have to re-open for reading
DocsWithFieldSet docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
vectorData,
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
vectorData,
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
docsWithField);
}
@Override
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
IndexInput vectorDataInput = null;
boolean success = false;
try {
// write the vector data to a temporary file
DocsWithFieldSet docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
tempVectorData,
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
tempVectorData,
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
// copy the temporary file vectors to the actual data file
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
docsWithField);
success = true;
final IndexInput finalVectorDataInput = vectorDataInput;
final RandomVectorScorerSupplier randomVectorScorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Byte.BYTES),
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Float.BYTES),
fieldInfo.getVectorSimilarityFunction());
};
return new FlatCloseableRandomVectorScorerSupplier(
() -> {
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempVectorData.getName());
},
docsWithField.cardinality(),
randomVectorScorerSupplier);
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempVectorData.getName());
}
}
}
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
DocsWithFieldSet docsWithField)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVInt(field.getVectorDimension());
// write docIDs
int count = docsWithField.cardinality();
meta.writeInt(count);
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
}
/**
* Writes the byte vector values to the output and returns a set of documents that contains
* vectors.
*/
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
}
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
ByteBuffer buffer =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
// write vector
float[] value = floatVectorValues.vectorValue();
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);
}
return docsWithField;
}
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData);
}
private abstract static class FieldWriter<T> extends FlatFieldVectorsWriter<T> {
private static final long SHALLOW_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private int lastDocID = -1;
@SuppressWarnings("unchecked")
static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
@Override
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
};
}
FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
super(indexWriter);
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
}
@Override
public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
T copy = copyValue(vectorValue);
docsWithField.add(docID);
vectors.add(copy);
lastDocID = docID;
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, copy);
}
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_RAM_BYTES_USED;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (vectors.size() == 0) return size;
return size
+ docsWithField.ramBytesUsed()
+ (long) vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ (long) vectors.size()
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize;
}
}
static final class FlatCloseableRandomVectorScorerSupplier
implements CloseableRandomVectorScorerSupplier {
private final RandomVectorScorerSupplier supplier;
private final Closeable onClose;
private final int numVectors;
FlatCloseableRandomVectorScorerSupplier(
Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) {
this.onClose = onClose;
this.supplier = supplier;
this.numVectors = numVectors;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return supplier.copy();
}
@Override
public void close() throws IOException {
onClose.close();
}
@Override
public int totalVectorCount() {
return numVectors;
}
}
}

View File

@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting
* the documents having values. The graph is used to power HNSW search. The format consists of two
* files, and uses {@link Lucene99ScalarQuantizedVectorsFormat} to store the actual vectors: For
* details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}.
*
* @lucene.experimental
*/
public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/
private final int maxConn;
/**
* The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
* HnswGraph} for details.
*/
private final int beamWidth;
/** The format for storing, reading, merging vectors on disk */
private final FlatVectorsFormat flatVectorsFormat;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
/** Constructs a format using default graph construction parameters */
public Lucene99HnswScalarQuantizedVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
*/
public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null);
}
/**
* Constructs a format using the given graph construction parameters and scalar quantization.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param configuredQuantile the quantile for scalar quantizing the vectors, when `null` it is
* calculated based on the vector field dimensions.
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
* generated by this format to do the merge
*/
public Lucene99HnswScalarQuantizedVectorsFormat(
int maxConn,
int beamWidth,
int numMergeWorkers,
Float configuredQuantile,
ExecutorService mergeExec) {
super("Lucene99HnswScalarQuantizedVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to"
+ MAXIMUM_MAX_CONN
+ "; maxConn="
+ maxConn);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to"
+ MAXIMUM_BEAM_WIDTH
+ "; beamWidth="
+ beamWidth);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException(
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
}
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(configuredQuantile);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
numMergeWorkers,
mergeExec);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
@Override
public String toString() {
return "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
@ -30,23 +31,9 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph;
/** /**
* Lucene 9.9 vector format, which encodes numeric vector values and an optional associated graph * Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting
* connecting the documents having values. The graph is used to power HNSW search. The format * the documents having values. The graph is used to power HNSW search. The format consists of two
* consists of three files, with an optional fourth file: * files, and requires a {@link FlatVectorsFormat} to store the actual vectors:
*
* <h2>.vec (vector data) file</h2>
*
* <p>For each field:
*
* <ul>
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
* note that only in sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case
* </ul>
* *
* <h2>.vex (vector index)</h2> * <h2>.vex (vector index)</h2>
* *
@ -74,14 +61,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <ul> * <ul>
* <li><b>[int32]</b> field number * <li><b>[int32]</b> field number
* <li><b>[int32]</b> vector similarity function ordinal * <li><b>[int32]</b> vector similarity function ordinal
* <li><b>[byte]</b> if equals to 1 indicates if the field is for quantized vectors
* <li><b>[int32]</b> if quantized: the configured quantile float int bits.
* <li><b>[int32]</b> if quantized: the calculated lower quantile float int32 bits.
* <li><b>[int32]</b> if quantized: the calculated upper quantile float int32 bits.
* <li><b>[vlong]</b> if quantized: offset to this field's vectors in the .veq file
* <li><b>[vlong]</b> if quantized: length of this field's vectors, in bytes in the .veq file
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
* <li><b>[vlong]</b> length of this field's vectors, in bytes
* <li><b>[vlong]</b> offset to this field's index in the .vex file * <li><b>[vlong]</b> offset to this field's index in the .vex file
* <li><b>[vlong]</b> length of this field's index data, in bytes * <li><b>[vlong]</b> length of this field's index data, in bytes
* <li><b>[vint]</b> dimension of this field's vectors * <li><b>[vint]</b> dimension of this field's vectors
@ -101,29 +80,13 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* </ul> * </ul>
* </ul> * </ul>
* *
* <h2>.veq (quantized vector data) file</h2>
*
* <p>For each field:
*
* <ul>
* <li>Vector data ordered by field, document ordinal, and vector dimension. Each vector dimension
* is stored as a single byte and every vector has a single float32 value for scoring
* corrections.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
* note that only in sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case
* </ul>
*
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta"; static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99HnswVectorsFormatData";
static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex"; static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
static final String META_EXTENSION = "vem"; static final String META_EXTENSION = "vem";
static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex"; static final String VECTOR_INDEX_EXTENSION = "vex";
public static final int VERSION_START = 0; public static final int VERSION_START = 0;
@ -135,7 +98,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large * <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
* numbers here will use an inordinate amount of heap * numbers here will use an inordinate amount of heap
*/ */
private static final int MAXIMUM_MAX_CONN = 512; static final int MAXIMUM_MAX_CONN = 512;
/** Default number of maximum connections per node */ /** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16; public static final int DEFAULT_MAX_CONN = 16;
@ -145,7 +108,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 = * maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 =
* 3200` * 3200`
*/ */
private static final int MAXIMUM_BEAM_WIDTH = 3200; static final int MAXIMUM_BEAM_WIDTH = 3200;
/** /**
* Default number of the size of the queue maintained while searching during a graph construction. * Default number of the size of the queue maintained while searching during a graph construction.
@ -170,20 +133,15 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/ */
private final int beamWidth; private final int beamWidth;
/** Should this codec scalar quantize float32 vectors and use this format */ /** The format for storing, reading, merging vectors on disk */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat; private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();
private final int numMergeWorkers; private final int numMergeWorkers;
private final ExecutorService mergeExec; private final ExecutorService mergeExec;
/** Constructs a format using default graph construction parameters */ /** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() { public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null); this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
}
public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
} }
/** /**
@ -193,7 +151,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* @param beamWidth the size of the queue maintained during graph construction. * @param beamWidth the size of the queue maintained during graph construction.
*/ */
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) { public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, null); this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
} }
/** /**
@ -201,18 +159,13 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* *
* @param maxConn the maximum number of connections to a node in the HNSW graph * @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction. * @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
* generated by this format to do the merge * generated by this format to do the merge
*/ */
public Lucene99HnswVectorsFormat( public Lucene99HnswVectorsFormat(
int maxConn, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
int numMergeWorkers,
ExecutorService mergeExec) {
super("Lucene99HnswVectorsFormat"); super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
@ -228,6 +181,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
+ "; beamWidth=" + "; beamWidth="
+ beamWidth); + beamWidth);
} }
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers > 1 && mergeExec == null) { if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"No executor service passed in when " + numMergeWorkers + " merge workers are requested"); "No executor service passed in when " + numMergeWorkers + " merge workers are requested");
@ -236,9 +191,6 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge"); "No executor service is needed as we'll use single thread to merge");
} }
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
this.numMergeWorkers = numMergeWorkers; this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec; this.mergeExec = mergeExec;
} }
@ -246,12 +198,17 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
@Override @Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter( return new Lucene99HnswVectorsWriter(
state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec); state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
numMergeWorkers,
mergeExec);
} }
@Override @Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state); return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
} }
@Override @Override
@ -265,8 +222,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
+ maxConn + maxConn
+ ", beamWidth=" + ", beamWidth="
+ beamWidth + beamWidth
+ ", quantizer=" + ", flatVectorFormat="
+ (scalarQuantizedVectorsFormat == null ? "none" : scalarQuantizedVectorsFormat.toString()) + flatVectorsFormat
+ ")"; + ")";
} }
} }

View File

@ -24,11 +24,9 @@ import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
@ -67,49 +65,14 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
private final FieldInfos fieldInfos; private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final IndexInput quantizedVectorData; private final FlatVectorsReader flatVectorsReader;
private final Lucene99ScalarQuantizedVectorsReader quantizedVectorsReader;
Lucene99HnswVectorsReader(SegmentReadState state) throws IOException { Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
this.fieldInfos = state.fieldInfos; throws IOException {
int versionMeta = readMetadata(state); this.flatVectorsReader = flatVectorsReader;
boolean success = false; boolean success = false;
try { this.fieldInfos = state.fieldInfos;
vectorData =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME);
vectorIndex =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
if (fields.values().stream().anyMatch(FieldEntry::hasQuantizedVectors)) {
quantizedVectorData =
openDataInput(
state,
versionMeta,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME);
quantizedVectorsReader = new Lucene99ScalarQuantizedVectorsReader(quantizedVectorData);
} else {
quantizedVectorData = null;
quantizedVectorsReader = null;
}
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
private int readMetadata(SegmentReadState state) throws IOException {
String metaFileName = String metaFileName =
IndexFileNames.segmentFileName( IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION); state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
@ -129,10 +92,30 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
} catch (Throwable exception) { } catch (Throwable exception) {
priorE = exception; priorE = exception;
} finally { } finally {
try {
CodecUtil.checkFooter(meta, priorE); CodecUtil.checkFooter(meta, priorE);
success = true;
} finally {
if (success == false) {
IOUtils.close(flatVectorsReader);
}
}
}
}
success = false;
try {
vectorIndex =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
} }
} }
return versionMeta;
} }
private static IndexInput openDataInput( private static IndexInput openDataInput(
@ -194,31 +177,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
+ " != " + " != "
+ fieldEntry.dimension); + fieldEntry.dimension);
} }
int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
"Vector data length "
+ fieldEntry.vectorDataLength
+ " not matching size="
+ fieldEntry.size
+ " * dim="
+ dimension
+ " * byteSize="
+ byteSize
+ " = "
+ numBytes);
}
if (fieldEntry.hasQuantizedVectors()) {
Lucene99ScalarQuantizedVectorsReader.validateFieldEntry(
info, fieldEntry.dimension, fieldEntry.size, fieldEntry.quantizedVectorDataLength);
}
} }
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
@ -249,58 +207,24 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
public long ramBytesUsed() { public long ramBytesUsed() {
return Lucene99HnswVectorsReader.SHALLOW_SIZE return Lucene99HnswVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap( + RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
+ flatVectorsReader.ramBytesUsed();
} }
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData); flatVectorsReader.checkIntegrity();
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
if (quantizedVectorsReader != null) {
quantizedVectorsReader.checkIntegrity();
}
} }
@Override @Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException { public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); return flatVectorsReader.getFloatVectorValues(field);
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
} }
@Override @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException { public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); return flatVectorsReader.getByteVectorValues(field);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
} }
@Override @Override
@ -313,42 +237,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return; return;
} }
if (fieldEntry.hasQuantizedVectors()) { RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
OffHeapQuantizedByteVectorValues vectorValues =
quantizedVectorsReader.getQuantizedVectorValues(
fieldEntry.quantizedOrdToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.quantizedVectorDataOffset,
fieldEntry.quantizedVectorDataLength);
if (vectorValues == null) {
return;
}
RandomVectorScorer scorer =
new ScalarQuantizedRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
HnswGraphSearcher.search( HnswGraphSearcher.search(
scorer, scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry), getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs)); scorer.getAcceptOrds(acceptDocs));
} else {
OffHeapFloatVectorValues vectorValues =
OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}
} }
@Override @Override
@ -361,22 +255,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) { || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return; return;
} }
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
OffHeapByteVectorValues vectorValues =
OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search( HnswGraphSearcher.search(
scorer, scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry), getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs)); scorer.getAcceptOrds(acceptDocs));
} }
@Override @Override
@ -399,32 +283,23 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(vectorData, vectorIndex, quantizedVectorData); IOUtils.close(flatVectorsReader, vectorIndex);
} }
@Override @Override
public OffHeapQuantizedByteVectorValues getQuantizedVectorValues(String field) public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
throws IOException { if (flatVectorsReader instanceof QuantizedVectorsReader) {
FieldEntry fieldEntry = fields.get(field); return ((QuantizedVectorsReader) flatVectorsReader).getQuantizedVectorValues(field);
if (fieldEntry == null || fieldEntry.hasQuantizedVectors() == false) {
return null;
} }
assert quantizedVectorsReader != null && fieldEntry.quantizedOrdToDoc != null; return null;
return quantizedVectorsReader.getQuantizedVectorValues(
fieldEntry.quantizedOrdToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.quantizedVectorDataOffset,
fieldEntry.quantizedVectorDataLength);
} }
@Override @Override
public ScalarQuantizer getQuantizationState(String fieldName) { public ScalarQuantizer getQuantizationState(String field) {
FieldEntry field = fields.get(fieldName); if (flatVectorsReader instanceof QuantizedVectorsReader) {
if (field == null || field.hasQuantizedVectors() == false) { return ((QuantizedVectorsReader) flatVectorsReader).getQuantizationState(field);
return null;
} }
return field.scalarQuantizer; return null;
} }
static class FieldEntry implements Accountable { static class FieldEntry implements Accountable {
@ -432,8 +307,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class); RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
final VectorSimilarityFunction similarityFunction; final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding; final VectorEncoding vectorEncoding;
final long vectorDataOffset;
final long vectorDataLength;
final long vectorIndexOffset; final long vectorIndexOffset;
final long vectorIndexLength; final long vectorIndexLength;
final int M; final int M;
@ -446,13 +319,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
final long offsetsOffset; final long offsetsOffset;
final int offsetsBlockShift; final int offsetsBlockShift;
final long offsetsLength; final long offsetsLength;
final OrdToDocDISIReaderConfiguration ordToDoc;
final float configuredQuantile, lowerQuantile, upperQuantile;
final long quantizedVectorDataOffset, quantizedVectorDataLength;
final ScalarQuantizer scalarQuantizer;
final boolean isQuantized;
final OrdToDocDISIReaderConfiguration quantizedOrdToDoc;
FieldEntry( FieldEntry(
IndexInput input, IndexInput input,
@ -461,36 +327,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
throws IOException { throws IOException {
this.similarityFunction = similarityFunction; this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding; this.vectorEncoding = vectorEncoding;
this.isQuantized = input.readByte() == 1;
// Has int8 quantization
if (isQuantized) {
configuredQuantile = Float.intBitsToFloat(input.readInt());
lowerQuantile = Float.intBitsToFloat(input.readInt());
upperQuantile = Float.intBitsToFloat(input.readInt());
quantizedVectorDataOffset = input.readVLong();
quantizedVectorDataLength = input.readVLong();
scalarQuantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, configuredQuantile);
} else {
configuredQuantile = -1;
lowerQuantile = -1;
upperQuantile = -1;
quantizedVectorDataOffset = -1;
quantizedVectorDataLength = -1;
scalarQuantizer = null;
}
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
vectorIndexOffset = input.readVLong(); vectorIndexOffset = input.readVLong();
vectorIndexLength = input.readVLong(); vectorIndexLength = input.readVLong();
dimension = input.readVInt(); dimension = input.readVInt();
size = input.readInt(); size = input.readInt();
if (isQuantized) {
quantizedOrdToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
} else {
quantizedOrdToDoc = null;
}
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
// read nodes by level // read nodes by level
M = input.readVInt(); M = input.readVInt();
numLevels = input.readVInt(); numLevels = input.readVInt();
@ -526,16 +366,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
return size; return size;
} }
boolean hasQuantizedVectors() {
return isQuantized;
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
return SHALLOW_SIZE return SHALLOW_SIZE
+ Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum() + Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum()
+ RamUsageEstimator.sizeOf(ordToDoc)
+ (quantizedOrdToDoc == null ? 0 : RamUsageEstimator.sizeOf(quantizedOrdToDoc))
+ RamUsageEstimator.sizeOf(offsetsMeta); + RamUsageEstimator.sizeOf(offsetsMeta);
} }
} }

View File

@ -18,40 +18,27 @@
package org.apache.lucene.codecs.lucene99; package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger; import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph;
@ -62,7 +49,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph; import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter; import org.apache.lucene.util.packed.DirectMonotonicWriter;
@ -73,11 +59,13 @@ import org.apache.lucene.util.packed.DirectMonotonicWriter;
*/ */
public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private static final long SHALLOW_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsWriter.class);
private final SegmentWriteState segmentWriteState; private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, quantizedVectorData, vectorIndex; private final IndexOutput meta, vectorIndex;
private final int M; private final int M;
private final int beamWidth; private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter; private final FlatVectorsWriter flatVectorWriter;
private final int numMergeWorkers; private final int numMergeWorkers;
private final ExecutorService mergeExec; private final ExecutorService mergeExec;
@ -88,42 +76,30 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
SegmentWriteState state, SegmentWriteState state,
int M, int M,
int beamWidth, int beamWidth,
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat, FlatVectorsWriter flatVectorWriter,
int numMergeWorkers, int numMergeWorkers,
ExecutorService mergeExec) ExecutorService mergeExec)
throws IOException { throws IOException {
this.M = M; this.M = M;
this.flatVectorWriter = flatVectorWriter;
this.beamWidth = beamWidth; this.beamWidth = beamWidth;
this.numMergeWorkers = numMergeWorkers; this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec; this.mergeExec = mergeExec;
segmentWriteState = state; segmentWriteState = state;
String metaFileName = String metaFileName =
IndexFileNames.segmentFileName( IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION); state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
String vectorDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION);
String indexDataFileName = String indexDataFileName =
IndexFileNames.segmentFileName( IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentInfo.name,
state.segmentSuffix, state.segmentSuffix,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION); Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
final String quantizedVectorDataFileName =
quantizedVectorsFormat != null
? IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION)
: null;
boolean success = false; boolean success = false;
try { try {
meta = state.directory.createOutput(metaFileName, state.context); meta = state.directory.createOutput(metaFileName, state.context);
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
vectorIndex = state.directory.createOutput(indexDataFileName, state.context); vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
CodecUtil.writeIndexHeader( CodecUtil.writeIndexHeader(
@ -132,34 +108,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
Lucene99HnswVectorsFormat.VERSION_CURRENT, Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentInfo.getId(),
state.segmentSuffix); state.segmentSuffix);
CodecUtil.writeIndexHeader(
vectorData,
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader( CodecUtil.writeIndexHeader(
vectorIndex, vectorIndex,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME, Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT, Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentInfo.getId(),
state.segmentSuffix); state.segmentSuffix);
if (quantizedVectorDataFileName != null) {
quantizedVectorData =
state.directory.createOutput(quantizedVectorDataFileName, state.context);
CodecUtil.writeIndexHeader(
quantizedVectorData,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
quantizedVectorsWriter =
new Lucene99ScalarQuantizedVectorsWriter(
quantizedVectorData, quantizedVectorsFormat.quantile);
} else {
quantizedVectorData = null;
quantizedVectorsWriter = null;
}
success = true; success = true;
} finally { } finally {
if (success == false) { if (success == false) {
@ -170,34 +124,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedVectorFieldWriter =
null;
// Quantization only supports FLOAT32 for now
if (quantizedVectorsWriter != null
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
quantizedVectorFieldWriter =
quantizedVectorsWriter.addField(fieldInfo, segmentWriteState.infoStream);
}
FieldWriter<?> newField = FieldWriter<?> newField =
FieldWriter.create( FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
fieldInfo, M, beamWidth, segmentWriteState.infoStream, quantizedVectorFieldWriter);
fields.add(newField); fields.add(newField);
return newField; return flatVectorWriter.addField(fieldInfo, newField);
} }
@Override @Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
flatVectorWriter.flush(maxDoc, sortMap);
for (FieldWriter<?> field : fields) { for (FieldWriter<?> field : fields) {
long[] quantizedVectorOffsetAndLen = null;
if (field.quantizedWriter != null) {
assert quantizedVectorsWriter != null;
quantizedVectorOffsetAndLen =
quantizedVectorsWriter.flush(sortMap, field.quantizedWriter, field.docsWithField);
}
if (sortMap == null) { if (sortMap == null) {
writeField(field, maxDoc, quantizedVectorOffsetAndLen); writeField(field);
} else { } else {
writeSortingField(field, maxDoc, sortMap, quantizedVectorOffsetAndLen); writeSortingField(field, sortMap);
} }
} }
} }
@ -208,40 +148,29 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
throw new IllegalStateException("already finished"); throw new IllegalStateException("already finished");
} }
finished = true; finished = true;
if (quantizedVectorsWriter != null) { flatVectorWriter.finish();
quantizedVectorsWriter.finish();
}
if (meta != null) { if (meta != null) {
// write end of fields marker // write end of fields marker
meta.writeInt(-1); meta.writeInt(-1);
CodecUtil.writeFooter(meta); CodecUtil.writeFooter(meta);
} }
if (vectorData != null) { if (vectorIndex != null) {
CodecUtil.writeFooter(vectorData);
CodecUtil.writeFooter(vectorIndex); CodecUtil.writeFooter(vectorIndex);
} }
} }
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long total = 0; long total = SHALLOW_RAM_BYTES_USED;
total += flatVectorWriter.ramBytesUsed();
for (FieldWriter<?> field : fields) { for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed(); total += field.ramBytesUsed();
} }
return total; return total;
} }
private void writeField(FieldWriter<?> fieldData, int maxDoc, long[] quantizedVecOffsetAndLen) private void writeField(FieldWriter<?> fieldData) throws IOException {
throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph // write graph
long vectorIndexOffset = vectorIndex.getFilePointer(); long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph(); OnHeapHnswGraph graph = fieldData.getGraph();
@ -249,43 +178,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta( writeMeta(
fieldData.isQuantized(),
fieldData.fieldInfo, fieldData.fieldInfo,
maxDoc,
fieldData.getConfiguredQuantile(),
fieldData.getMinQuantile(),
fieldData.getMaxQuantile(),
quantizedVecOffsetAndLen,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset, vectorIndexOffset,
vectorIndexLength, vectorIndexLength,
fieldData.docsWithField, fieldData.docsWithField.cardinality(),
graph, graph,
graphLevelNodeOffsets); graphLevelNodeOffsets);
} }
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException { private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
byte[] vector = (byte[]) v;
vectorData.writeBytes(vector, vector.length);
}
}
private void writeSortingField(
FieldWriter<?> fieldData,
int maxDoc,
Sorter.DocMap sortMap,
long[] quantizedVectorOffsetAndLen)
throws IOException { throws IOException {
final int[] docIdOffsets = new int[sortMap.size()]; final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document) int offset = 1; // 0 means no vector for this (field, document)
@ -310,15 +211,6 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
} }
doc++; doc++;
} }
// write vector values
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph // write graph
long vectorIndexOffset = vectorIndex.getFilePointer(); long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph(); OnHeapHnswGraph graph = fieldData.getGraph();
@ -327,44 +219,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta( writeMeta(
fieldData.isQuantized(),
fieldData.fieldInfo, fieldData.fieldInfo,
maxDoc,
fieldData.getConfiguredQuantile(),
fieldData.getMinQuantile(),
fieldData.getMaxQuantile(),
quantizedVectorOffsetAndLen,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset, vectorIndexOffset,
vectorIndexLength, vectorIndexLength,
newDocsWithField, fieldData.docsWithField.cardinality(),
mockGraph, mockGraph,
graphLevelNodeOffsets); graphLevelNodeOffsets);
} }
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}
/** /**
* Reconstructs the graph given the old and new node ids. * Reconstructs the graph given the old and new node ids.
* *
@ -475,116 +337,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); CloseableRandomVectorScorerSupplier scorerSupplier =
IndexOutput tempVectorData = null; flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState);
IndexInput vectorDataInput = null;
CloseableRandomVectorScorerSupplier scorerSupplier = null;
boolean success = false; boolean success = false;
try { try {
ScalarQuantizer scalarQuantizer = null;
long[] quantizedVectorDataOffsetAndLength = null;
// If we have configured quantization and are FLOAT32
if (quantizedVectorsWriter != null
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
// We need the quantization parameters to write to the meta file
scalarQuantizer = quantizedVectorsWriter.mergeQuantiles(fieldInfo, mergeState);
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"Merged quantiles field: "
+ fieldInfo.name
+ " newly merged quantile: "
+ scalarQuantizer);
}
assert scalarQuantizer != null;
quantizedVectorDataOffsetAndLength = new long[2];
quantizedVectorDataOffsetAndLength[0] = quantizedVectorData.alignFilePointer(Float.BYTES);
scorerSupplier =
quantizedVectorsWriter.mergeOneField(
segmentWriteState, fieldInfo, mergeState, scalarQuantizer);
quantizedVectorDataOffsetAndLength[1] =
quantizedVectorData.getFilePointer() - quantizedVectorDataOffsetAndLength[0];
}
final DocsWithFieldSet docsWithField;
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
// If we extract vector storage, this could be cleaner.
// But for now, vector storage & index creation/storage live together.
if (scorerSupplier == null) {
tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
// copy the temporary file vectors to the actual data file
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
final RandomVectorScorerSupplier innerScoreSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
};
final String tempFileName = tempVectorData.getName();
final IndexInput finalVectorDataInput = vectorDataInput;
scorerSupplier =
new CloseableRandomVectorScorerSupplier() {
boolean closed = false;
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return innerScoreSupplier.scorer(ord);
}
@Override
public void close() throws IOException {
if (closed) {
return;
}
closed = true;
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
// here we just return the inner out since we only need to close this outside copy
return innerScoreSupplier.copy();
}
};
} else {
// No need to use temporary file as we don't have to re-open for reading
docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer(); long vectorIndexOffset = vectorIndex.getFilePointer();
// build the graph using the temporary vector data // build the graph using the temporary vector data
// we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction // we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
@ -592,7 +348,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
// TODO: separate random access vector values from DocIdSetIterator? // TODO: separate random access vector values from DocIdSetIterator?
OnHeapHnswGraph graph = null; OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null; int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) { if (scorerSupplier.totalVectorCount() > 0) {
// build graph // build graph
HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier); HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
for (int i = 0; i < mergeState.liveDocs.length; i++) { for (int i = 0; i < mergeState.liveDocs.length; i++) {
@ -608,23 +364,17 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
} }
graph = graph =
merger.merge( merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality()); mergedVectorIterator,
segmentWriteState.infoStream,
scorerSupplier.totalVectorCount());
vectorIndexNodeOffsets = writeGraph(graph); vectorIndexNodeOffsets = writeGraph(graph);
} }
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta( writeMeta(
scalarQuantizer != null,
fieldInfo, fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
scalarQuantizer == null ? null : scalarQuantizer.getConfiguredQuantile(),
scalarQuantizer == null ? null : scalarQuantizer.getLowerQuantile(),
scalarQuantizer == null ? null : scalarQuantizer.getUpperQuantile(),
quantizedVectorDataOffsetAndLength,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset, vectorIndexOffset,
vectorIndexLength, vectorIndexLength,
docsWithField, scorerSupplier.totalVectorCount(),
graph, graph,
vectorIndexNodeOffsets); vectorIndexNodeOffsets);
success = true; success = true;
@ -632,11 +382,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
if (success) { if (success) {
IOUtils.close(scorerSupplier); IOUtils.close(scorerSupplier);
} else { } else {
IOUtils.closeWhileHandlingException(scorerSupplier, vectorDataInput, tempVectorData); IOUtils.closeWhileHandlingException(scorerSupplier);
if (tempVectorData != null) {
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempVectorData.getName());
}
} }
} }
} }
@ -652,7 +398,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int countOnLevel0 = graph.size(); int countOnLevel0 = graph.size();
int[][] offsets = new int[graph.numLevels()][]; int[][] offsets = new int[graph.numLevels()][];
for (int level = 0; level < graph.numLevels(); level++) { for (int level = 0; level < graph.numLevels(); level++) {
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level)); int[] sortedNodes = NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
offsets[level] = new int[sortedNodes.length]; offsets[level] = new int[sortedNodes.length];
int nodeOffsetId = 0; int nodeOffsetId = 0;
for (int node : sortedNodes) { for (int node : sortedNodes) {
@ -680,80 +426,21 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
return offsets; return offsets;
} }
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
int[] sortedNodes = new int[nodesOnLevel.size()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
sortedNodes[n] = nodesOnLevel.nextInt();
}
Arrays.sort(sortedNodes);
return sortedNodes;
}
private HnswGraphMerger createGraphMerger(
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
if (mergeExec != null) {
return new ConcurrentHnswMerger(
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
}
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
}
private void writeMeta( private void writeMeta(
boolean isQuantized,
FieldInfo field, FieldInfo field,
int maxDoc,
Float configuredQuantizationQuantile,
Float lowerQuantile,
Float upperQuantile,
long[] quantizedVectorDataOffsetAndLen,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset, long vectorIndexOffset,
long vectorIndexLength, long vectorIndexLength,
DocsWithFieldSet docsWithField, int count,
HnswGraph graph, HnswGraph graph,
int[][] graphLevelNodeOffsets) int[][] graphLevelNodeOffsets)
throws IOException { throws IOException {
meta.writeInt(field.number); meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal()); meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal()); meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeByte(isQuantized ? (byte) 1 : (byte) 0);
if (isQuantized) {
assert lowerQuantile != null
&& upperQuantile != null
&& quantizedVectorDataOffsetAndLen != null;
assert quantizedVectorDataOffsetAndLen.length == 2;
meta.writeInt(
Float.floatToIntBits(
configuredQuantizationQuantile != null
? configuredQuantizationQuantile
: calculateDefaultQuantile(field.getVectorDimension())));
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
meta.writeVLong(quantizedVectorDataOffsetAndLen[0]);
meta.writeVLong(quantizedVectorDataOffsetAndLen[1]);
} else {
assert configuredQuantizationQuantile == null
&& lowerQuantile == null
&& upperQuantile == null
&& quantizedVectorDataOffsetAndLen == null;
}
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVLong(vectorIndexOffset); meta.writeVLong(vectorIndexOffset);
meta.writeVLong(vectorIndexLength); meta.writeVLong(vectorIndexLength);
meta.writeVInt(field.getVectorDimension()); meta.writeVInt(field.getVectorDimension());
// write docIDs
int count = docsWithField.cardinality();
meta.writeInt(count); meta.writeInt(count);
if (isQuantized) {
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
}
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
meta.writeVInt(M); meta.writeVInt(M);
// write graph nodes on each level // write graph nodes on each level
if (graph == null) { if (graph == null) {
@ -799,109 +486,47 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
} }
} }
/** private HnswGraphMerger createGraphMerger(
* Writes the byte vector values to the output and returns a set of documents that contains FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
* vectors. if (mergeExec != null) {
*/ return new ConcurrentHnswMerger(
private static DocsWithFieldSet writeByteVectorData( fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
} }
return docsWithField; return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
}
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
ByteBuffer buffer =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
// write vector
float[] value = floatVectorValues.vectorValue();
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);
}
return docsWithField;
} }
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData); IOUtils.close(meta, vectorIndex, flatVectorWriter);
} }
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> { private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
private final FieldInfo fieldInfo; private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField; private final DocsWithFieldSet docsWithField;
private final List<T> vectors; private final List<T> vectors;
private final HnswGraphBuilder hnswGraphBuilder; private final HnswGraphBuilder hnswGraphBuilder;
private final Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter
quantizedWriter;
private int lastDocID = -1; private int lastDocID = -1;
private int node = 0; private int node = 0;
static FieldWriter<?> create( static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream,
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter writer)
throws IOException { throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) { return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream, writer) { case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream);
@Override case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream);
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream, writer) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
}; };
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
FieldWriter( FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream,
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedWriter)
throws IOException { throws IOException {
this.fieldInfo = fieldInfo; this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet(); this.docsWithField = new DocsWithFieldSet();
this.quantizedWriter = quantizedWriter;
vectors = new ArrayList<>(); vectors = new ArrayList<>();
if (quantizedWriter != null RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, fieldInfo.getVectorDimension());
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
throw new IllegalArgumentException(
"Vector encoding ["
+ VectorEncoding.FLOAT32
+ "] required for quantized vectors; provided="
+ fieldInfo.getVectorEncoding());
}
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RandomVectorScorerSupplier scorerSupplier = RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) { switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes( case BYTE -> RandomVectorScorerSupplier.createBytes(
@ -925,20 +550,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
+ "\" appears more than once in this document (only one value is allowed per field)"); + "\" appears more than once in this document (only one value is allowed per field)");
} }
assert docID > lastDocID; assert docID > lastDocID;
T copy = copyValue(vectorValue); vectors.add(vectorValue);
if (quantizedWriter != null) {
assert vectorValue instanceof float[];
quantizedWriter.addValue((float[]) copy);
}
docsWithField.add(docID); docsWithField.add(docID);
vectors.add(copy);
hnswGraphBuilder.addGraphNode(node); hnswGraphBuilder.addGraphNode(node);
node++; node++;
lastDocID = docID; lastDocID = docID;
} }
@Override
public T copyValue(T vectorValue) {
throw new UnsupportedOperationException();
}
OnHeapHnswGraph getGraph() { OnHeapHnswGraph getGraph() {
if (vectors.size() > 0) { if (node > 0) {
return hnswGraphBuilder.getGraph(); return hnswGraphBuilder.getGraph();
} else { } else {
return null; return null;
@ -947,32 +572,11 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
if (vectors.size() == 0) return 0; return SHALLOW_SIZE
long quantizationSpace = quantizedWriter != null ? quantizedWriter.ramBytesUsed() : 0L; + docsWithField.ramBytesUsed()
return docsWithField.ramBytesUsed()
+ (long) vectors.size() + (long) vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ (long) vectors.size() + hnswGraphBuilder.getGraph().ramBytesUsed();
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize
+ hnswGraphBuilder.getGraph().ramBytesUsed()
+ quantizationSpace;
}
Float getConfiguredQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getQuantile();
}
Float getMinQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getMinQuantile();
}
Float getMaxQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getMaxQuantile();
}
boolean isQuantized() {
return quantizedWriter != null;
} }
} }

View File

@ -17,20 +17,31 @@
package org.apache.lucene.codecs.lucene99; package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
/** /**
* Format supporting vector quantization, storage, and retrieval * Format supporting vector quantization, storage, and retrieval
* *
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene99ScalarQuantizedVectorsFormat { public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat"; static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
static final int VERSION_START = 0; static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START; static final int VERSION_CURRENT = VERSION_START;
static final String QUANTIZED_VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsData"; static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta";
static final String QUANTIZED_VECTOR_DATA_EXTENSION = "veq"; static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData";
static final String META_EXTENSION = "vemq";
static final String VECTOR_DATA_EXTENSION = "veq";
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();
/** The minimum quantile */ /** The minimum quantile */
private static final float MINIMUM_QUANTILE = 0.9f; private static final float MINIMUM_QUANTILE = 0.9f;
@ -74,6 +85,24 @@ public final class Lucene99ScalarQuantizedVectorsFormat {
@Override @Override
public String toString() { public String toString() {
return NAME + "(name=" + NAME + ", quantile=" + quantile + ")"; return NAME
+ "(name="
+ NAME
+ ", quantile="
+ quantile
+ ", rawVectorFormat="
+ rawVectorFormat
+ ")";
}
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsWriter(
state, quantile, rawVectorFormat.fieldsWriter(state));
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state));
} }
} }

View File

@ -18,46 +18,128 @@
package org.apache.lucene.codecs.lucene99; package org.apache.lucene.codecs.lucene99;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/** /**
* Reads Scalar Quantized vectors from the index segments along with index data structures. * Reads Scalar Quantized vectors from the index segments along with index data structures.
* *
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene99ScalarQuantizedVectorsReader { public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader
implements QuantizedVectorsReader {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput quantizedVectorData; private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader;
Lucene99ScalarQuantizedVectorsReader(IndexInput quantizedVectorData) { Lucene99ScalarQuantizedVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
this.quantizedVectorData = quantizedVectorData; throws IOException {
int versionMeta = -1;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
boolean success = false;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
Throwable priorE = null;
try {
versionMeta =
CodecUtil.checkIndexHeader(
meta,
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta, state.fieldInfos);
} catch (Throwable exception) {
priorE = exception;
} finally {
try {
CodecUtil.checkFooter(meta, priorE);
success = true;
} finally {
if (success == false) {
IOUtils.close(rawVectorsReader);
}
}
}
}
success = false;
this.rawVectorsReader = rawVectorsReader;
try {
quantizedVectorData =
openDataInput(
state,
versionMeta,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
} }
static void validateFieldEntry( private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
FieldInfo info, int fieldDimension, int size, long quantizedVectorDataLength) { for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
FieldInfo info = infos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
FieldEntry fieldEntry = readField(meta);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
}
}
static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
int dimension = info.getVectorDimension(); int dimension = info.getVectorDimension();
if (dimension != fieldDimension) { if (dimension != fieldEntry.dimension) {
throw new IllegalStateException( throw new IllegalStateException(
"Inconsistent vector dimension for field=\"" "Inconsistent vector dimension for field=\""
+ info.name + info.name
+ "\"; " + "\"; "
+ dimension + dimension
+ " != " + " != "
+ fieldDimension); + fieldEntry.dimension);
} }
// int8 quantized and calculated stored offset. // int8 quantized and calculated stored offset.
long quantizedVectorBytes = dimension + Float.BYTES; long quantizedVectorBytes = dimension + Float.BYTES;
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, size); long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size);
if (numQuantizedVectorBytes != quantizedVectorDataLength) { if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException( throw new IllegalStateException(
"Quantized vector data length " "Quantized vector data length "
+ quantizedVectorDataLength + fieldEntry.vectorDataLength
+ " not matching size=" + " not matching size="
+ size + fieldEntry.size
+ " * (dim=" + " * (dim="
+ dimension + dimension
+ " + 4)" + " + 4)"
@ -66,23 +148,184 @@ public final class Lucene99ScalarQuantizedVectorsReader {
} }
} }
@Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
rawVectorsReader.checkIntegrity();
CodecUtil.checksumEntireFile(quantizedVectorData); CodecUtil.checksumEntireFile(quantizedVectorData);
} }
OffHeapQuantizedByteVectorValues getQuantizedVectorValues( @Override
OrdToDocDISIReaderConfiguration configuration, public FloatVectorValues getFloatVectorValues(String field) throws IOException {
int dimension, return rawVectorsReader.getFloatVectorValues(field);
int size, }
long quantizedVectorDataOffset,
long quantizedVectorDataLength) @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return rawVectorsReader.getByteVectorValues(field);
}
private static IndexInput openDataInput(
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
throws IOException { throws IOException {
String fileName =
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
IndexInput in = state.directory.openInput(fileName, state.context);
boolean success = false;
try {
int versionVectorData =
CodecUtil.checkIndexHeader(
in,
codecName,
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException(
"Format versions mismatch: meta="
+ versionMeta
+ ", "
+ codecName
+ "="
+ versionVectorData,
in);
}
CodecUtil.retrieveChecksum(in);
success = true;
return in;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(in);
}
}
}
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
if (fieldEntry.scalarQuantizer == null) {
return rawVectorsReader.getRandomVectorScorer(field, target);
}
OffHeapQuantizedByteVectorValues vectorValues =
OffHeapQuantizedByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
quantizedVectorData);
return new ScalarQuantizedRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
}
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
return rawVectorsReader.getRandomVectorScorer(field, target);
}
@Override
public void close() throws IOException {
IOUtils.close(quantizedVectorData, rawVectorsReader);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size +=
RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
size += rawVectorsReader.ramBytesUsed();
return size;
}
private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction);
}
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorSimilarityFunction.values()[similarityFunctionId];
}
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
@Override
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
FieldEntry fieldEntry = fields.get(fieldName);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return OffHeapQuantizedByteVectorValues.load( return OffHeapQuantizedByteVectorValues.load(
configuration, fieldEntry.ordToDoc,
dimension, fieldEntry.dimension,
size, fieldEntry.size,
quantizedVectorDataOffset, fieldEntry.vectorDataOffset,
quantizedVectorDataLength, fieldEntry.vectorDataLength,
quantizedVectorData); quantizedVectorData);
} }
@Override
public ScalarQuantizer getQuantizationState(String fieldName) {
FieldEntry fieldEntry = fields.get(fieldName);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return fieldEntry.scalarQuantizer;
}
private static class FieldEntry implements Accountable {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final int dimension;
final long vectorDataOffset;
final long vectorDataLength;
final ScalarQuantizer scalarQuantizer;
final int size;
final OrdToDocDISIReaderConfiguration ordToDoc;
FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
dimension = input.readVInt();
size = input.readInt();
if (size > 0) {
float configuredQuantile = Float.intBitsToFloat(input.readInt());
float minQuantile = Float.intBitsToFloat(input.readInt());
float maxQuantile = Float.intBitsToFloat(input.readInt());
scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, configuredQuantile);
} else {
scalarQuantizer = null;
}
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
}
}
} }

View File

@ -17,6 +17,7 @@
package org.apache.lucene.codecs.lucene99; package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@ -29,13 +30,18 @@ import java.nio.ByteOrder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
@ -44,7 +50,6 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
@ -59,9 +64,9 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
* *
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
private static final long BASE_RAM_BYTES_USED = private static final long SHALLOW_RAM_BYTES_USED =
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class); shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
// Used for determining when merged quantiles shifted too far from individual segment quantiles. // Used for determining when merged quantiles shifted too far from individual segment quantiles.
@ -82,67 +87,214 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
// the quantization error) and the condition is sensitive enough to detect all adversarial cases, // the quantization error) and the condition is sensitive enough to detect all adversarial cases,
// such as merging clustered data. // such as merging clustered data.
private static final float REQUANTIZATION_LIMIT = 0.2f; private static final float REQUANTIZATION_LIMIT = 0.2f;
private final IndexOutput quantizedVectorData; private final SegmentWriteState segmentWriteState;
private final List<FieldWriter> fields = new ArrayList<>();
private final IndexOutput meta, quantizedVectorData;
private final Float quantile; private final Float quantile;
private final FlatVectorsWriter rawVectorDelegate;
private boolean finished; private boolean finished;
Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) { Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state, Float quantile, FlatVectorsWriter rawVectorDelegate)
throws IOException {
this.quantile = quantile; this.quantile = quantile;
this.quantizedVectorData = quantizedVectorData; segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
String quantizedVectorDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
this.rawVectorDelegate = rawVectorDelegate;
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
quantizedVectorData =
state.directory.createOutput(quantizedVectorDataFileName, state.context);
CodecUtil.writeIndexHeader(
meta,
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
quantizedVectorData,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
} }
QuantizationFieldVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) { @Override
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { public FlatFieldVectorsWriter<?> addField(
throw new IllegalArgumentException( FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
"Only float32 vector fields are supported for quantization"); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
}
float quantile = float quantile =
this.quantile == null this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension()) ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
: this.quantile; : this.quantile;
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { FieldWriter quantizedWriter =
infoStream.message( new FieldWriter(quantile, fieldInfo, segmentWriteState.infoStream, indexWriter);
QUANTIZED_VECTOR_COMPONENT, fields.add(quantizedWriter);
"quantizing field=" indexWriter = quantizedWriter;
+ fieldInfo.name
+ " dimension="
+ fieldInfo.getVectorDimension()
+ " quantile="
+ quantile);
} }
return QuantizationFieldVectorWriter.create(fieldInfo, quantile, infoStream); return rawVectorDelegate.addField(fieldInfo, indexWriter);
} }
long[] flush( @Override
Sorter.DocMap sortMap, QuantizationFieldVectorWriter field, DocsWithFieldSet docsWithField) public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
throws IOException { rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
// Since we know we will not be searching for additional indexing, we can just write the
// the vectors directly to the new segment.
// No need to use temporary file as we don't have to re-open for reading
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(quantizedVectorData, byteVectorValues);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
float quantile =
this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
: this.quantile;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
quantile,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
}
}
@Override
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
// Simply merge the underlying delegate, which just copies the raw vector data to a new
// segment file
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
return mergeOneFieldToIndex(
segmentWriteState, fieldInfo, mergeState, mergedQuantizationState);
}
// We only merge the delegate, since the field type isn't float32, quantization wasn't
// supported, so bypass it.
return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter field : fields) {
field.finish(); field.finish();
return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField); if (sortMap == null) {
writeField(field, maxDoc);
} else {
writeSortingField(field, maxDoc, sortMap);
}
}
} }
void finish() throws IOException { @Override
public void finish() throws IOException {
if (finished) { if (finished) {
throw new IllegalStateException("already finished"); throw new IllegalStateException("already finished");
} }
finished = true; finished = true;
rawVectorDelegate.finish();
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (quantizedVectorData != null) { if (quantizedVectorData != null) {
CodecUtil.writeFooter(quantizedVectorData); CodecUtil.writeFooter(quantizedVectorData);
} }
} }
private long[] writeField(QuantizationFieldVectorWriter fieldData) throws IOException { @Override
long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); public long ramBytesUsed() {
writeQuantizedVectors(fieldData); long total = SHALLOW_RAM_BYTES_USED;
long quantizedVectorDataLength = for (FieldWriter field : fields) {
quantizedVectorData.getFilePointer() - quantizedVectorDataOffset; total += field.ramBytesUsed();
return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength}; }
return total;
} }
private void writeQuantizedVectors(QuantizationFieldVectorWriter fieldData) throws IOException { private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeQuantizedVectors(fieldData);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
quantile,
fieldData.minQuantile,
fieldData.maxQuantile,
fieldData.docsWithField);
}
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
Float configuredQuantizationQuantile,
Float lowerQuantile,
Float upperQuantile,
DocsWithFieldSet docsWithField)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVInt(field.getVectorDimension());
int count = docsWithField.cardinality();
meta.writeInt(count);
if (count > 0) {
assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
meta.writeInt(
Float.floatToIntBits(
configuredQuantizationQuantile != null
? configuredQuantizationQuantile
: calculateDefaultQuantile(field.getVectorDimension())));
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
}
// write docIDs
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
}
private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim]; byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (float[] v : fieldData.floatVectors) { for (float[] v : fieldData.floatVectors) {
if (fieldData.normalize) { if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length); System.arraycopy(v, 0, copy, 0, copy.length);
@ -151,7 +303,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
} }
float offsetCorrection = float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
quantizedVectorData.writeBytes(vector, vector.length); quantizedVectorData.writeBytes(vector, vector.length);
offsetBuffer.putFloat(offsetCorrection); offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
@ -159,14 +311,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
} }
} }
private long[] writeSortingField( private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
QuantizationFieldVectorWriter fieldData,
Sorter.DocMap sortMap,
DocsWithFieldSet docsWithField)
throws IOException { throws IOException {
final int[] docIdOffsets = new int[sortMap.size()]; final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document) int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = docsWithField.iterator(); DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc(); for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS; docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) { docID = iterator.nextDoc()) {
@ -175,13 +324,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
} }
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord final int[] ordMap = new int[offset - 1]; // new ord to old ord
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
int ord = 0; int ord = 0;
int doc = 0; int doc = 0;
for (int docIdOffset : docIdOffsets) { for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) { if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1; ordMap[ord] = docIdOffset - 1;
oldOrdMap[docIdOffset - 1] = ord;
newDocsWithField.add(doc); newDocsWithField.add(doc);
ord++; ord++;
} }
@ -192,16 +339,22 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeSortedQuantizedVectors(fieldData, ordMap); writeSortedQuantizedVectors(fieldData, ordMap);
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
return new long[] {vectorDataOffset, quantizedVectorLength}; fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
quantizedVectorLength,
quantile,
fieldData.minQuantile,
fieldData.maxQuantile,
newDocsWithField);
} }
void writeSortedQuantizedVectors(QuantizationFieldVectorWriter fieldData, int[] ordMap) private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim]; byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (int ordinal : ordMap) { for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal); float[] v = fieldData.floatVectors.get(ordinal);
if (fieldData.normalize) { if (fieldData.normalize) {
@ -209,9 +362,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
VectorUtil.l2normalize(copy); VectorUtil.l2normalize(copy);
v = copy; v = copy;
} }
float offsetCorrection = float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
quantizedVectorData.writeBytes(vector, vector.length); quantizedVectorData.writeBytes(vector, vector.length);
offsetBuffer.putFloat(offsetCorrection); offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
@ -219,10 +371,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
} }
} }
ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws IOException { private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState)
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { throws IOException {
return null; assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32;
}
float quantile = float quantile =
this.quantile == null this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension()) ? calculateDefaultQuantile(fieldInfo.getVectorDimension())
@ -230,15 +381,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile); return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
} }
ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField( private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
SegmentWriteState segmentWriteState, SegmentWriteState segmentWriteState,
FieldInfo fieldInfo, FieldInfo fieldInfo,
MergeState mergeState, MergeState mergeState,
ScalarQuantizer mergedQuantizationState) ScalarQuantizer mergedQuantizationState)
throws IOException { throws IOException {
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
return null;
}
IndexOutput tempQuantizedVectorData = IndexOutput tempQuantizedVectorData =
segmentWriteState.directory.createTempOutput( segmentWriteState.directory.createTempOutput(
quantizedVectorData.getName(), "temp", segmentWriteState.context); quantizedVectorData.getName(), "temp", segmentWriteState.context);
@ -257,7 +406,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
tempQuantizedVectorData.getName(), segmentWriteState.context); tempQuantizedVectorData.getName(), segmentWriteState.context);
quantizedVectorData.copyBytes( quantizedVectorData.copyBytes(
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength()); quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
CodecUtil.retrieveChecksum(quantizationDataInput); CodecUtil.retrieveChecksum(quantizationDataInput);
float quantile =
this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
: this.quantile;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
quantile,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
success = true; success = true;
final IndexInput finalQuantizationDataInput = quantizationDataInput; final IndexInput finalQuantizationDataInput = quantizationDataInput;
return new ScalarQuantizedCloseableRandomVectorScorerSupplier( return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
@ -265,6 +428,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
IOUtils.close(finalQuantizationDataInput); IOUtils.close(finalQuantizationDataInput);
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
}, },
docsWithField.cardinality(),
new ScalarQuantizedRandomVectorScorerSupplier( new ScalarQuantizedRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(), fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState, mergedQuantizationState,
@ -427,43 +591,35 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
} }
@Override @Override
public long ramBytesUsed() { public void close() throws IOException {
return BASE_RAM_BYTES_USED; IOUtils.close(meta, quantizedVectorData, rawVectorDelegate);
} }
static class QuantizationFieldVectorWriter implements Accountable { static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
private static final long SHALLOW_SIZE = private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
shallowSizeOfInstance(QuantizationFieldVectorWriter.class);
private final int dim;
private final List<float[]> floatVectors; private final List<float[]> floatVectors;
private final boolean normalize; private final FieldInfo fieldInfo;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final float quantile; private final float quantile;
private final InfoStream infoStream; private final InfoStream infoStream;
private final boolean normalize;
private float minQuantile = Float.POSITIVE_INFINITY; private float minQuantile = Float.POSITIVE_INFINITY;
private float maxQuantile = Float.NEGATIVE_INFINITY; private float maxQuantile = Float.NEGATIVE_INFINITY;
private boolean finished; private boolean finished;
private final DocsWithFieldSet docsWithField;
static QuantizationFieldVectorWriter create( @SuppressWarnings("unchecked")
FieldInfo fieldInfo, float quantile, InfoStream infoStream) { FieldWriter(
return new QuantizationFieldVectorWriter(
fieldInfo.getVectorDimension(),
quantile,
fieldInfo.getVectorSimilarityFunction(),
infoStream);
}
QuantizationFieldVectorWriter(
int dim,
float quantile, float quantile,
VectorSimilarityFunction vectorSimilarityFunction, FieldInfo fieldInfo,
InfoStream infoStream) { InfoStream infoStream,
this.dim = dim; KnnFieldVectorsWriter<?> indexWriter) {
super((KnnFieldVectorsWriter<float[]>) indexWriter);
this.quantile = quantile; this.quantile = quantile;
this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE; this.fieldInfo = fieldInfo;
this.vectorSimilarityFunction = vectorSimilarityFunction; this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
this.floatVectors = new ArrayList<>(); this.floatVectors = new ArrayList<>();
this.infoStream = infoStream; this.infoStream = infoStream;
this.docsWithField = new DocsWithFieldSet();
} }
void finish() throws IOException { void finish() throws IOException {
@ -475,15 +631,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
return; return;
} }
ScalarQuantizer quantizer = ScalarQuantizer quantizer =
ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile); ScalarQuantizer.fromVectors(
new FloatVectorWrapper(
floatVectors,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
quantile);
minQuantile = quantizer.getLowerQuantile(); minQuantile = quantizer.getLowerQuantile();
maxQuantile = quantizer.getUpperQuantile(); maxQuantile = quantizer.getUpperQuantile();
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
infoStream.message( infoStream.message(
QUANTIZED_VECTOR_COMPONENT, QUANTIZED_VECTOR_COMPONENT,
"quantized field=" "quantized field="
+ " dimension="
+ dim
+ " quantile=" + " quantile="
+ quantile + quantile
+ " minQuantile=" + " minQuantile="
@ -494,24 +652,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
finished = true; finished = true;
} }
public void addValue(float[] vectorValue) throws IOException {
floatVectors.add(vectorValue);
}
float getMinQuantile() {
assert finished;
return minQuantile;
}
float getMaxQuantile() {
assert finished;
return maxQuantile;
}
float getQuantile() {
return quantile;
}
ScalarQuantizer createQuantizer() { ScalarQuantizer createQuantizer() {
assert finished; assert finished;
return new ScalarQuantizer(minQuantile, maxQuantile, quantile); return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
@ -519,8 +659,26 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
if (floatVectors.size() == 0) return SHALLOW_SIZE; long size = SHALLOW_SIZE;
return SHALLOW_SIZE + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF; if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (floatVectors.size() == 0) return size;
return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
}
@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
docsWithField.add(docID);
floatVectors.add(vectorValue);
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, vectorValue);
}
}
@Override
public float[] copyValue(float[] vectorValue) {
throw new UnsupportedOperationException();
} }
} }
@ -613,6 +771,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
// Either our quantization parameters are way different than the merged ones // Either our quantization parameters are way different than the merged ones
// Or we have never been quantized. // Or we have never been quantized.
if (reader == null if (reader == null
|| reader.getQuantizationState(fieldInfo.name) == null
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) { || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
sub = sub =
new QuantizedByteVectorValueSub( new QuantizedByteVectorValueSub(
@ -702,6 +861,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
private final FloatVectorValues values; private final FloatVectorValues values;
private final ScalarQuantizer quantizer; private final ScalarQuantizer quantizer;
private final byte[] quantizedVector; private final byte[] quantizedVector;
private final float[] normalizedVector;
private float offsetValue = 0f; private float offsetValue = 0f;
private final VectorSimilarityFunction vectorSimilarityFunction; private final VectorSimilarityFunction vectorSimilarityFunction;
@ -714,6 +874,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
this.quantizer = quantizer; this.quantizer = quantizer;
this.quantizedVector = new byte[values.dimension()]; this.quantizedVector = new byte[values.dimension()];
this.vectorSimilarityFunction = vectorSimilarityFunction; this.vectorSimilarityFunction = vectorSimilarityFunction;
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
this.normalizedVector = new float[values.dimension()];
} else {
this.normalizedVector = null;
}
} }
@Override @Override
@ -745,8 +910,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
int doc = values.nextDoc(); int doc = values.nextDoc();
if (doc != NO_MORE_DOCS) { if (doc != NO_MORE_DOCS) {
offsetValue = quantize();
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
} }
return doc; return doc;
} }
@ -755,10 +919,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
int doc = values.advance(target); int doc = values.advance(target);
if (doc != NO_MORE_DOCS) { if (doc != NO_MORE_DOCS) {
quantize();
}
return doc;
}
private void quantize() throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
offsetValue =
quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
} else {
offsetValue = offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
} }
return doc;
} }
} }
@ -767,11 +942,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
private final ScalarQuantizedRandomVectorScorerSupplier supplier; private final ScalarQuantizedRandomVectorScorerSupplier supplier;
private final Closeable onClose; private final Closeable onClose;
private final int numVectors;
ScalarQuantizedCloseableRandomVectorScorerSupplier( ScalarQuantizedCloseableRandomVectorScorerSupplier(
Closeable onClose, ScalarQuantizedRandomVectorScorerSupplier supplier) { Closeable onClose, int numVectors, ScalarQuantizedRandomVectorScorerSupplier supplier) {
this.onClose = onClose; this.onClose = onClose;
this.supplier = supplier; this.supplier = supplier;
this.numVectors = numVectors;
} }
@Override @Override
@ -788,6 +965,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
public void close() throws IOException { public void close() throws IOException {
onClose.close(); onClose.close();
} }
@Override
public int totalVectorCount() {
return numVectors;
}
} }
private static final class OffsetCorrectedQuantizedByteVectorValues private static final class OffsetCorrectedQuantizedByteVectorValues

View File

@ -98,8 +98,6 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
} }
} }
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
private int doc = -1; private int doc = -1;
@ -138,7 +136,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs; return acceptDocs;
} }
} }
@ -196,7 +194,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) { if (acceptDocs == null) {
return null; return null;
} }
@ -268,7 +266,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
} }
@Override @Override
Bits getAcceptOrds(Bits acceptDocs) { public Bits getAcceptOrds(Bits acceptDocs) {
return null; return null;
} }
} }

View File

@ -25,7 +25,8 @@ import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorer;
/** Quantized vector scorer */ /** Quantized vector scorer */
final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer { final class ScalarQuantizedRandomVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer<byte[]> {
private static float quantizeQuery( private static float quantizeQuery(
float[] query, float[] query,
@ -54,6 +55,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
RandomAccessQuantizedByteVectorValues values, RandomAccessQuantizedByteVectorValues values,
byte[] query, byte[] query,
float queryOffset) { float queryOffset) {
super(values);
this.quantizedQuery = query; this.quantizedQuery = query;
this.queryOffset = queryOffset; this.queryOffset = queryOffset;
this.similarity = similarityFunction; this.similarity = similarityFunction;
@ -65,6 +67,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
ScalarQuantizer scalarQuantizer, ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values, RandomAccessQuantizedByteVectorValues values,
float[] query) { float[] query) {
super(values);
byte[] quantizedQuery = new byte[query.length]; byte[] quantizedQuery = new byte[query.length];
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer); float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
this.quantizedQuery = quantizedQuery; this.quantizedQuery = quantizedQuery;

View File

@ -26,5 +26,6 @@ import java.io.Closeable;
* <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily * <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily
* closeable * closeable
*/ */
public interface CloseableRandomVectorScorerSupplier public interface CloseableRandomVectorScorerSupplier extends Closeable, RandomVectorScorerSupplier {
extends Closeable, RandomVectorScorerSupplier {} int totalVectorCount();
}

View File

@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw; package org.apache.lucene.util.hnsw;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.util.Bits;
/** /**
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
@ -56,4 +57,14 @@ public interface RandomAccessVectorValues<T> {
default int ordToDoc(int ord) { default int ordToDoc(int ord) {
return ord; return ord;
} }
/**
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
*
* @param acceptDocs the accept docs
* @return the accept docs
*/
default Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
} }

View File

@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
public interface RandomVectorScorer { public interface RandomVectorScorer {
@ -30,6 +31,31 @@ public interface RandomVectorScorer {
*/ */
float score(int node) throws IOException; float score(int node) throws IOException;
/**
* @return the maximum possible ordinal for this scorer
*/
int maxOrd();
/**
* Translates vector ordinal to the correct document ID. By default, this is an identity function.
*
* @param ord the vector ordinal
* @return the document Id for that vector ordinal
*/
default int ordToDoc(int ord) {
return ord;
}
/**
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
*
* @param acceptDocs the accept docs
* @return the accept docs
*/
default Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
/** /**
* Creates a default scorer for float vectors. * Creates a default scorer for float vectors.
* *
@ -53,7 +79,12 @@ public interface RandomVectorScorer {
+ " differs from field dimension: " + " differs from field dimension: "
+ vectors.dimension()); + vectors.dimension());
} }
return node -> similarityFunction.compare(query, vectors.vectorValue(node)); return new AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.vectorValue(node));
}
};
} }
/** /**
@ -79,6 +110,44 @@ public interface RandomVectorScorer {
+ " differs from field dimension: " + " differs from field dimension: "
+ vectors.dimension()); + vectors.dimension());
} }
return node -> similarityFunction.compare(query, vectors.vectorValue(node)); return new AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.vectorValue(node));
}
};
}
/**
* Creates a default scorer for random access vectors.
*
* @param <T> the type of the vector values
*/
abstract class AbstractRandomVectorScorer<T> implements RandomVectorScorer {
private final RandomAccessVectorValues<T> values;
/**
* Creates a new scorer for the given vector values.
*
* @param values the vector values
*/
public AbstractRandomVectorScorer(RandomAccessVectorValues<T> values) {
this.values = values;
}
@Override
public int maxOrd() {
return values.size();
}
@Override
public int ordToDoc(int ord) {
return values.ordToDoc(ord);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return values.getAcceptOrds(acceptDocs);
}
} }
} }

View File

@ -87,8 +87,12 @@ public interface RandomVectorScorerSupplier {
@Override @Override
public RandomVectorScorer scorer(int ord) throws IOException { public RandomVectorScorer scorer(int ord) throws IOException {
return cand -> return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); @Override
public float score(int cand) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
};
} }
@Override @Override
@ -115,8 +119,12 @@ public interface RandomVectorScorerSupplier {
@Override @Override
public RandomVectorScorer scorer(int ord) throws IOException { public RandomVectorScorer scorer(int ord) throws IOException {
return cand -> return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); @Override
public float score(int cand) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
};
} }
@Override @Override

View File

@ -14,3 +14,4 @@
# limitations under the License. # limitations under the License.
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat

View File

@ -47,10 +47,9 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
return new Lucene99Codec() { return new Lucene99Codec() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat( return new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
new Lucene99ScalarQuantizedVectorsFormat());
} }
}; };
} }
@ -145,12 +144,11 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
new FilterCodec("foo", Codec.getDefault()) { new FilterCodec("foo", Codec.getDefault()) {
@Override @Override
public KnnVectorsFormat knnVectorsFormat() { public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99HnswVectorsFormat( return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null);
10, 20, new Lucene99ScalarQuantizedVectorsFormat(0.9f));
} }
}; };
String expectedString = String expectedString =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9))"; "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
} }
} }

View File

@ -37,7 +37,7 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
} }
}; };
String expectedString = String expectedString =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=none)"; "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat())";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
} }

View File

@ -1,79 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorsFormat extends LuceneTestCase {
public void testDefaultQuantile() {
float defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(99);
assertEquals(0.99f, defaultQuantile, 1e-5);
defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(1);
assertEquals(0.9f, defaultQuantile, 1e-5);
defaultQuantile =
Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(Integer.MAX_VALUE - 2);
assertEquals(1.0f, defaultQuantile, 1e-5);
}
public void testLimits() {
expectThrows(
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(0.89f));
expectThrows(
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(1.1f));
}
public void testQuantileMergeWithMissing() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
quantiles.add(null);
List<Integer> segmentSizes = List.of(1, 1, 1, 1);
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f));
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(List.of(), List.of(), 0.1f));
}
public void testQuantileMerge() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
List<Integer> segmentSizes = List.of(1, 1, 1);
ScalarQuantizer merged =
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
assertEquals(0.2f, merged.getLowerQuantile(), 1e-5);
assertEquals(0.3f, merged.getUpperQuantile(), 1e-5);
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
}
public void testQuantileMergeWithDifferentSegmentSizes() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
List<Integer> segmentSizes = List.of(1, 2, 3);
ScalarQuantizer merged =
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
assertEquals(0.2333333f, merged.getLowerQuantile(), 1e-5);
assertEquals(0.3333333f, merged.getUpperQuantile(), 1e-5);
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
}
}

View File

@ -32,9 +32,9 @@ import java.util.concurrent.CountDownLatch;
import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
@ -83,12 +83,7 @@ public class TestKnnGraph extends LuceneTestCase {
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1; int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity]; similarityFunction = VectorSimilarityFunction.values()[similarity];
vectorEncoding = randomVectorEncoding(); vectorEncoding = randomVectorEncoding();
boolean quantized = randomBoolean();
Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat =
vectorEncoding.equals(VectorEncoding.FLOAT32) && randomBoolean()
? new Lucene99ScalarQuantizedVectorsFormat(1f)
: null;
codec = codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) { new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override @Override
@ -96,14 +91,16 @@ public class TestKnnGraph extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat( return quantized
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, scalarQuantizedVectorsFormat); ? new Lucene99HnswScalarQuantizedVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH)
: new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
} }
}; };
} }
}; };
if (vectorEncoding == VectorEncoding.FLOAT32 && scalarQuantizedVectorsFormat == null) { if (vectorEncoding == VectorEncoding.FLOAT32) {
float32Codec = codec; float32Codec = codec;
} else { } else {
float32Codec = float32Codec =

View File

@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
public class TestNeighborArray extends LuceneTestCase { public class TestNeighborArray extends LuceneTestCase {
@ -190,7 +191,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addOutOfOrder(7, Float.NaN); neighbors.addOutOfOrder(7, Float.NaN);
neighbors.addOutOfOrder(6, Float.NaN); neighbors.addOutOfOrder(6, Float.NaN);
neighbors.addOutOfOrder(4, Float.NaN); neighbors.addOutOfOrder(4, Float.NaN);
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 1); int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 1);
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
@ -206,7 +207,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addOutOfOrder(17, Float.NaN); neighbors.addOutOfOrder(17, Float.NaN);
neighbors.addOutOfOrder(16, Float.NaN); neighbors.addOutOfOrder(16, Float.NaN);
neighbors.addOutOfOrder(14, Float.NaN); neighbors.addOutOfOrder(14, Float.NaN);
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 11); int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 11);
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors); assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors);
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
@ -223,4 +224,21 @@ public class TestNeighborArray extends LuceneTestCase {
assertEquals(nodes[i], neighbors.node[i]); assertEquals(nodes[i], neighbors.node[i]);
} }
} }
interface TestRandomVectorScorer extends RandomVectorScorer {
@Override
default int maxOrd() {
throw new UnsupportedOperationException();
}
@Override
default int ordToDoc(int ord) {
throw new UnsupportedOperationException();
}
@Override
default Bits getAcceptOrds(Bits acceptDocs) {
throw new UnsupportedOperationException();
}
}
} }