Reduce heap usage for knn index writers (#13538)

* Reduce heap usage for knn index writers

* iter

* fixing heap usage & adding changes

* javadocs
This commit is contained in:
Benjamin Trent 2024-07-10 10:28:48 -04:00 committed by GitHub
parent 026d661e5f
commit 428fdb5291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 183 additions and 128 deletions

View File

@@ -280,6 +280,8 @@ Optimizations
* GITHUB#13175: Stop double-checking priority queue inserts in some FacetCount classes. (Jakub Slowinski)
* GITHUB#13538: Slightly reduce heap usage for HNSW and scalar quantized vector writers. (Ben Trent)
Changes in runtime behavior
---------------------

View File

@@ -17,7 +17,10 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
/**
* Vectors' writer for a field
@@ -26,20 +29,25 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
* @lucene.experimental
*/
public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
/**
* The delegate to write to, can be null When non-null, all vectors seen should be written to the
* delegate along with being written to the flat vectors.
* @return a list of vectors to be written
*/
protected final KnnFieldVectorsWriter<T> indexingDelegate;
public abstract List<T> getVectors();
/**
* Sole constructor that expects some indexingDelegate. All vectors seen should be written to the
* delegate along with being written to the flat vectors.
* @return the docsWithFieldSet for the field writer
*/
public abstract DocsWithFieldSet getDocsWithFieldSet();
/**
* indicates that this writer is done and no new vectors are allowed to be added
*
* @param indexingDelegate the delegate to write to, can be null
* @throws IOException if an I/O error occurs
*/
protected FlatFieldVectorsWriter(KnnFieldVectorsWriter<T> indexingDelegate) {
this.indexingDelegate = indexingDelegate;
}
public abstract void finish() throws IOException;
/**
* @return true if the writer is done and no new vectors are allowed to be added
*/
public abstract boolean isFinished();
}

View File

@@ -18,7 +18,6 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
@@ -46,21 +45,14 @@ public abstract class FlatVectorsWriter extends KnnVectorsWriter {
}
/**
* Add a new field for indexing, allowing the user to provide a writer that the flat vectors
* writer can delegate to if additional indexing logic is required.
* Add a new field for indexing
*
* @param fieldInfo fieldInfo of the field to add
* @param indexWriter the writer to delegate to, can be null
* @return a writer for the field
* @throws IOException if an I/O error occurs when adding the field
*/
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
public abstract FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;
/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way

View File

@@ -27,7 +27,6 @@ import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@@ -111,18 +110,12 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
}
@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo);
fields.add(newField);
return newField;
}
@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
@@ -131,6 +124,7 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
} else {
writeSortingField(field, maxDoc, sortMap);
}
field.finish();
}
}
@@ -403,22 +397,20 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private boolean finished;
private int lastDocID = -1;
@SuppressWarnings("unchecked")
static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
static FieldWriter<?> create(FieldInfo fieldInfo) {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<byte[]>(fieldInfo) {
@Override
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<float[]>(fieldInfo) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
@@ -427,8 +419,8 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
};
}
FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
super(indexWriter);
FieldWriter(FieldInfo fieldInfo) {
super();
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
@@ -437,6 +429,9 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
@Override
public void addValue(int docID, T vectorValue) throws IOException {
if (finished) {
throw new IllegalStateException("already finished, cannot add more values");
}
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
@@ -448,17 +443,11 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
docsWithField.add(docID);
vectors.add(copy);
lastDocID = docID;
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, copy);
}
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_RAM_BYTES_USED;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (vectors.size() == 0) return size;
return size
+ docsWithField.ramBytesUsed()
@@ -468,6 +457,29 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize;
}
@Override
public List<T> getVectors() {
return vectors;
}
@Override
public DocsWithFieldSet getDocsWithFieldSet() {
return docsWithField;
}
@Override
public void finish() throws IOException {
if (finished) {
return;
}
this.finished = true;
}
@Override
public boolean isFinished() {
return finished;
}
}
static final class FlatCloseableRandomVectorScorerSupplier

View File

@@ -24,9 +24,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
@@ -130,12 +132,13 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
FieldWriter<?> newField =
FieldWriter.create(
flatVectorWriter.getFlatVectorScorer(),
flatVectorWriter.addField(fieldInfo),
fieldInfo,
M,
beamWidth,
segmentWriteState.infoStream);
fields.add(newField);
return flatVectorWriter.addField(fieldInfo, newField);
return newField;
}
@Override
@@ -171,8 +174,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
// The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
total += flatVectorWriter.ramBytesUsed();
for (FieldWriter<?> field : fields) {
// the field tracks the delegate field usage
total += field.ramBytesUsed();
}
return total;
}
@@ -187,17 +192,19 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
fieldData.fieldInfo,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField.cardinality(),
fieldData.getDocsWithFieldSet().cardinality(),
graph,
graphLevelNodeOffsets);
}
private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
throws IOException {
final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
final int[] oldOrdMap = new int[fieldData.docsWithField.cardinality()]; // old ord to new ord
final int[] ordMap =
new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
final int[] oldOrdMap =
new int[fieldData.getDocsWithFieldSet().cardinality()]; // old ord to new ord
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, oldOrdMap, ordMap, null);
mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, oldOrdMap, ordMap, null);
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();
@@ -209,7 +216,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
fieldData.fieldInfo,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField.cardinality(),
fieldData.getDocsWithFieldSet().cardinality(),
mockGraph,
graphLevelNodeOffsets);
}
@@ -521,42 +528,65 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
private final FieldInfo fieldInfo;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
@SuppressWarnings("unchecked")
static FieldWriter<?> create(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FlatVectorsScorer scorer,
FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream)
throws IOException {
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<byte[]>(scorer, fieldInfo, M, beamWidth, infoStream);
case FLOAT32 -> new FieldWriter<float[]>(scorer, fieldInfo, M, beamWidth, infoStream);
case BYTE -> new FieldWriter<>(
scorer,
(FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter,
fieldInfo,
M,
beamWidth,
infoStream);
case FLOAT32 -> new FieldWriter<>(
scorer,
(FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
fieldInfo,
M,
beamWidth,
infoStream);
};
}
@SuppressWarnings("unchecked")
FieldWriter(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FlatVectorsScorer scorer,
FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes(
(List<byte[]>) vectors, fieldInfo.getVectorDimension()));
(List<byte[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
case FLOAT32 -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats(
(List<float[]>) vectors, fieldInfo.getVectorDimension()));
(List<float[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter);
}
@Override
@@ -567,20 +597,23 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
vectors.add(vectorValue);
docsWithField.add(docID);
flatFieldVectorsWriter.addValue(docID, vectorValue);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}
public DocsWithFieldSet getDocsWithFieldSet() {
return flatFieldVectorsWriter.getDocsWithFieldSet();
}
@Override
public T copyValue(T vectorValue) {
throw new UnsupportedOperationException();
}
OnHeapHnswGraph getGraph() {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getGraph();
} else {
@@ -591,9 +624,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE
+ docsWithField.ramBytesUsed()
+ (long) vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ flatFieldVectorsWriter.ramBytesUsed()
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}

View File

@@ -30,8 +30,8 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
@@ -56,7 +56,6 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
@@ -195,8 +194,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) {
throw new IllegalArgumentException(
@@ -205,6 +204,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
+ " is not supported for odd vector dimensions; vector dimension="
+ fieldInfo.getVectorDimension());
}
@SuppressWarnings("unchecked")
FieldWriter quantizedWriter =
new FieldWriter(
confidenceInterval,
@@ -212,11 +212,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
compress,
fieldInfo,
segmentWriteState.infoStream,
indexWriter);
(FlatFieldVectorsWriter<float[]>) rawVectorDelegate);
fields.add(quantizedWriter);
indexWriter = quantizedWriter;
return quantizedWriter;
}
return rawVectorDelegate.addField(fieldInfo, indexWriter);
return rawVectorDelegate;
}
@Override
@@ -270,12 +270,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter field : fields) {
field.finish();
ScalarQuantizer quantizer = field.createQuantizer();
if (sortMap == null) {
writeField(field, maxDoc);
writeField(field, maxDoc, quantizer);
} else {
writeSortingField(field, maxDoc, sortMap);
writeSortingField(field, maxDoc, sortMap, quantizer);
}
field.finish();
}
}
@@ -299,15 +300,18 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
// The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
total += rawVectorDelegate.ramBytesUsed();
for (FieldWriter field : fields) {
// the field tracks the delegate field usage
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
private void writeField(FieldWriter fieldData, int maxDoc, ScalarQuantizer scalarQuantizer)
throws IOException {
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeQuantizedVectors(fieldData);
writeQuantizedVectors(fieldData, scalarQuantizer);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
@@ -318,9 +322,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
fieldData.docsWithField);
scalarQuantizer.getLowerQuantile(),
scalarQuantizer.getUpperQuantile(),
fieldData.getDocsWithFieldSet());
}
private void writeMeta(
@@ -365,8 +369,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
}
private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
private void writeQuantizedVectors(FieldWriter fieldData, ScalarQuantizer scalarQuantizer)
throws IOException {
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
@@ -375,7 +379,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (float[] v : fieldData.floatVectors) {
assert fieldData.getVectors().isEmpty() || scalarQuantizer != null;
for (float[] v : fieldData.getVectors()) {
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
@@ -396,16 +401,18 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
}
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
private void writeSortingField(
FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap, ScalarQuantizer scalarQuantizer)
throws IOException {
final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
final int[] ordMap =
new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField);
mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeSortedQuantizedVectors(fieldData, ordMap);
writeSortedQuantizedVectors(fieldData, ordMap, scalarQuantizer);
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldData.fieldInfo,
@@ -415,13 +422,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
scalarQuantizer.getLowerQuantile(),
scalarQuantizer.getUpperQuantile(),
newDocsWithField);
}
private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
private void writeSortedQuantizedVectors(
FieldWriter fieldData, int[] ordMap, ScalarQuantizer scalarQuantizer) throws IOException {
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
@@ -431,7 +438,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal);
float[] v = fieldData.getVectors().get(ordinal);
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
@@ -744,44 +751,51 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
private final List<float[]> floatVectors;
private final FieldInfo fieldInfo;
private final Float confidenceInterval;
private final byte bits;
private final boolean compress;
private final InfoStream infoStream;
private final boolean normalize;
private float minQuantile = Float.POSITIVE_INFINITY;
private float maxQuantile = Float.NEGATIVE_INFINITY;
private boolean finished;
private final DocsWithFieldSet docsWithField;
private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
@SuppressWarnings("unchecked")
FieldWriter(
Float confidenceInterval,
byte bits,
boolean compress,
FieldInfo fieldInfo,
InfoStream infoStream,
KnnFieldVectorsWriter<?> indexWriter) {
super((KnnFieldVectorsWriter<float[]>) indexWriter);
FlatFieldVectorsWriter<float[]> indexWriter) {
super();
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.fieldInfo = fieldInfo;
this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
this.floatVectors = new ArrayList<>();
this.infoStream = infoStream;
this.docsWithField = new DocsWithFieldSet();
this.compress = compress;
this.flatFieldVectorsWriter = Objects.requireNonNull(indexWriter);
}
void finish() throws IOException {
@Override
public boolean isFinished() {
return finished && flatFieldVectorsWriter.isFinished();
}
@Override
public void finish() throws IOException {
if (finished) {
return;
}
assert flatFieldVectorsWriter.isFinished();
finished = true;
}
ScalarQuantizer createQuantizer() throws IOException {
assert flatFieldVectorsWriter.isFinished();
List<float[]> floatVectors = flatFieldVectorsWriter.getVectors();
if (floatVectors.size() == 0) {
finished = true;
return;
return new ScalarQuantizer(0, 0, bits);
}
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
ScalarQuantizer quantizer =
@@ -791,8 +805,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
fieldInfo.getVectorSimilarityFunction(),
confidenceInterval,
bits);
minQuantile = quantizer.getLowerQuantile();
maxQuantile = quantizer.getUpperQuantile();
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
@@ -802,41 +814,39 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
+ " bits="
+ bits
+ " minQuantile="
+ minQuantile
+ quantizer.getLowerQuantile()
+ " maxQuantile="
+ maxQuantile);
+ quantizer.getUpperQuantile());
}
finished = true;
}
ScalarQuantizer createQuantizer() {
assert finished;
return new ScalarQuantizer(minQuantile, maxQuantile, bits);
return quantizer;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (floatVectors.size() == 0) return size;
return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
size += flatFieldVectorsWriter.ramBytesUsed();
return size;
}
@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
docsWithField.add(docID);
floatVectors.add(vectorValue);
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, vectorValue);
}
flatFieldVectorsWriter.addValue(docID, vectorValue);
}
@Override
public float[] copyValue(float[] vectorValue) {
throw new UnsupportedOperationException();
}
@Override
public List<float[]> getVectors() {
return flatFieldVectorsWriter.getVectors();
}
@Override
public DocsWithFieldSet getDocsWithFieldSet() {
return flatFieldVectorsWriter.getDocsWithFieldSet();
}
}
static class FloatVectorWrapper extends FloatVectorValues {