mirror of https://github.com/apache/lucene.git
Reduce heap usage for knn index writers (#13538)
* Reduce heap usage for knn index writers
* iter
* fixing heap usage & adding changes
* javadocs
parent 026d661e5f
commit 428fdb5291
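The heap saving comes from de-duplicating per-field buffers: previously the flat vectors writer and the HNSW or scalar-quantized writer wrapping it each kept their own List of raw vectors plus a DocsWithFieldSet; after this change only the flat field writer buffers them, and the wrapping writers read them back through getVectors() and getDocsWithFieldSet(). A minimal sketch of the new ownership model, using hypothetical stand-in classes (FlatStorage and GraphIndexer are not Lucene types):

import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins (not Lucene classes) showing the single-buffer pattern.
class FlatStorage {
  private final List<float[]> vectors = new ArrayList<>(); // the only on-heap copy

  void add(float[] v) {
    vectors.add(v);
  }

  List<float[]> getVectors() {
    return vectors; // wrapping writers read this instead of buffering their own copy
  }
}

class GraphIndexer {
  private final FlatStorage storage; // delegate that owns the vectors

  GraphIndexer(FlatStorage storage) {
    this.storage = storage;
  }

  void addValue(float[] v) {
    storage.add(v); // written once; no second List<float[]> here
    int ord = storage.getVectors().size() - 1;
    // ... add node `ord` to the graph, scoring against storage.getVectors()
  }
}

The trade-off is a hard dependency: a wrapping writer can no longer operate without its flat delegate, which the new constructors enforce with Objects.requireNonNull.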
lucene/CHANGES.txt
@@ -280,6 +280,8 @@ Optimizations

 * GITHUB#13175: Stop double-checking priority queue inserts in some FacetCount classes. (Jakub Slowinski)

+* GITHUB#13538: Slightly reduce heap usage for HNSW and scalar quantized vector writers. (Ben Trent)
+
 Changes in runtime behavior
 ---------------------
lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java
@@ -17,7 +17,10 @@
 package org.apache.lucene.codecs.hnsw;

 import java.io.IOException;
+import java.util.List;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.index.DocsWithFieldSet;

 /**
  * Vectors' writer for a field
@@ -26,20 +29,25 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
  * @lucene.experimental
  */
 public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {

   /**
-   * The delegate to write to, can be null. When non-null, all vectors seen should be written to the
-   * delegate along with being written to the flat vectors.
+   * @return a list of vectors to be written
    */
-  protected final KnnFieldVectorsWriter<T> indexingDelegate;
+  public abstract List<T> getVectors();

   /**
-   * Sole constructor that expects some indexingDelegate. All vectors seen should be written to the
-   * delegate along with being written to the flat vectors.
+   * @return the docsWithFieldSet for the field writer
+   */
+  public abstract DocsWithFieldSet getDocsWithFieldSet();
+
+  /**
+   * indicates that this writer is done and no new vectors are allowed to be added
    *
-   * @param indexingDelegate the delegate to write to, can be null
+   * @throws IOException if an I/O error occurs
    */
-  protected FlatFieldVectorsWriter(KnnFieldVectorsWriter<T> indexingDelegate) {
-    this.indexingDelegate = indexingDelegate;
-  }
+  public abstract void finish() throws IOException;
+
+  /**
+   * @return true if the writer is done and no new vectors are allowed to be added
+   */
+  public abstract boolean isFinished();
 }
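For implementers of the new contract, here is a minimal sketch of a concrete subclass (BufferingFlatFieldVectorsWriter is a hypothetical name; the real implementation is Lucene99FlatVectorsWriter.FieldWriter further down in this diff):

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;

// Hypothetical subclass sketching the lifecycle the new abstract methods imply.
abstract class BufferingFlatFieldVectorsWriter<T> extends FlatFieldVectorsWriter<T> {
  private final List<T> vectors = new ArrayList<>();
  private final DocsWithFieldSet docsWithField = new DocsWithFieldSet();
  private boolean finished;

  @Override
  public void addValue(int docID, T vectorValue) throws IOException {
    if (finished) {
      throw new IllegalStateException("already finished, cannot add more values");
    }
    docsWithField.add(docID);
    vectors.add(copyValue(vectorValue)); // copyValue is still abstract, as before
  }

  @Override
  public List<T> getVectors() {
    return vectors;
  }

  @Override
  public DocsWithFieldSet getDocsWithFieldSet() {
    return docsWithField;
  }

  @Override
  public void finish() throws IOException {
    finished = true; // idempotent; flush may call it defensively
  }

  @Override
  public boolean isFinished() {
    return finished;
  }
}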
lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java
@@ -18,7 +18,6 @@
 package org.apache.lucene.codecs.hnsw;

 import java.io.IOException;
-import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.MergeState;
@@ -46,21 +45,14 @@ public abstract class FlatVectorsWriter extends KnnVectorsWriter {
   }

   /**
-   * Add a new field for indexing, allowing the user to provide a writer that the flat vectors
-   * writer can delegate to if additional indexing logic is required.
+   * Add a new field for indexing
    *
    * @param fieldInfo fieldInfo of the field to add
-   * @param indexWriter the writer to delegate to, can be null
    * @return a writer for the field
    * @throws IOException if an I/O error occurs when adding the field
    */
-  public abstract FlatFieldVectorsWriter<?> addField(
-      FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
-
-  @Override
-  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
-    return addField(fieldInfo, null);
-  }
+  public abstract FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;

   /**
    * Write the field for merging, providing a scorer over the newly merged flat vectors. This way
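Callers that previously passed an indexing delegate into addField now simply register the field and hold on to the returned writer. Roughly, as a fragment assuming a FlatVectorsWriter flatWriter and a FieldInfo fieldInfo are in scope:

// Sketch of the new call pattern (not a complete program).
FlatFieldVectorsWriter<?> flatFieldWriter = flatWriter.addField(fieldInfo);
// A graph- or quantization-level writer keeps this reference, forwards
// addValue(docID, vector) to it, and later reads flatFieldWriter.getVectors()
// and flatFieldWriter.getDocsWithFieldSet() to build its own structures.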
lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java
@@ -27,7 +27,6 @@ import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@@ -111,18 +110,12 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
   }

   @Override
-  public FlatFieldVectorsWriter<?> addField(
-      FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
-    FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    FieldWriter<?> newField = FieldWriter.create(fieldInfo);
     fields.add(newField);
     return newField;
   }

-  @Override
-  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
-    return addField(fieldInfo, null);
-  }
-
   @Override
   public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
     for (FieldWriter<?> field : fields) {
@@ -131,6 +124,7 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
       } else {
         writeSortingField(field, maxDoc, sortMap);
       }
+      field.finish();
     }
   }

@@ -403,22 +397,20 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
     private final int dim;
     private final DocsWithFieldSet docsWithField;
     private final List<T> vectors;
+    private boolean finished;

     private int lastDocID = -1;

-    @SuppressWarnings("unchecked")
-    static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
+    static FieldWriter<?> create(FieldInfo fieldInfo) {
       int dim = fieldInfo.getVectorDimension();
       return switch (fieldInfo.getVectorEncoding()) {
-        case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
-            fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
+        case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<byte[]>(fieldInfo) {
           @Override
           public byte[] copyValue(byte[] value) {
             return ArrayUtil.copyOfSubArray(value, 0, dim);
           }
         };
-        case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
-            fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
+        case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<float[]>(fieldInfo) {
           @Override
           public float[] copyValue(float[] value) {
             return ArrayUtil.copyOfSubArray(value, 0, dim);
@@ -427,8 +419,8 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
       };
     }

-    FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
-      super(indexWriter);
+    FieldWriter(FieldInfo fieldInfo) {
+      super();
       this.fieldInfo = fieldInfo;
       this.dim = fieldInfo.getVectorDimension();
       this.docsWithField = new DocsWithFieldSet();
@@ -437,6 +429,9 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {

     @Override
     public void addValue(int docID, T vectorValue) throws IOException {
+      if (finished) {
+        throw new IllegalStateException("already finished, cannot add more values");
+      }
       if (docID == lastDocID) {
         throw new IllegalArgumentException(
             "VectorValuesField \""
@@ -448,17 +443,11 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
       docsWithField.add(docID);
       vectors.add(copy);
       lastDocID = docID;
-      if (indexingDelegate != null) {
-        indexingDelegate.addValue(docID, copy);
-      }
     }

     @Override
     public long ramBytesUsed() {
       long size = SHALLOW_RAM_BYTES_USED;
-      if (indexingDelegate != null) {
-        size += indexingDelegate.ramBytesUsed();
-      }
       if (vectors.size() == 0) return size;
       return size
           + docsWithField.ramBytesUsed()
@@ -468,6 +457,29 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
               * fieldInfo.getVectorDimension()
               * fieldInfo.getVectorEncoding().byteSize;
     }
+
+    @Override
+    public List<T> getVectors() {
+      return vectors;
+    }
+
+    @Override
+    public DocsWithFieldSet getDocsWithFieldSet() {
+      return docsWithField;
+    }
+
+    @Override
+    public void finish() throws IOException {
+      if (finished) {
+        return;
+      }
+      this.finished = true;
+    }
+
+    @Override
+    public boolean isFinished() {
+      return finished;
+    }
   }

   static final class FlatCloseableRandomVectorScorerSupplier
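The new finished flag turns the field writer into an explicit two-phase object: buffer values, then freeze. A small self-contained illustration of the guard (SimpleBuffer is a hypothetical stand-in, not a Lucene class):

import java.util.ArrayList;
import java.util.List;

// Stand-alone illustration of the finished-flag guard added to FieldWriter.
class SimpleBuffer {
  private final List<float[]> vectors = new ArrayList<>();
  private boolean finished;

  void add(float[] v) {
    if (finished) {
      throw new IllegalStateException("already finished, cannot add more values");
    }
    vectors.add(v);
  }

  void finish() {
    finished = true; // safe to call more than once, like FieldWriter#finish
  }

  public static void main(String[] args) {
    SimpleBuffer b = new SimpleBuffer();
    b.add(new float[] {1f});
    b.finish();
    b.finish(); // no-op on the second call
    // calling b.add(...) here would throw IllegalStateException
  }
}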
lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
@@ -24,9 +24,11 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Objects;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
 import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
 import org.apache.lucene.index.DocsWithFieldSet;
@@ -130,12 +132,13 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
     FieldWriter<?> newField =
         FieldWriter.create(
             flatVectorWriter.getFlatVectorScorer(),
+            flatVectorWriter.addField(fieldInfo),
             fieldInfo,
             M,
             beamWidth,
             segmentWriteState.infoStream);
     fields.add(newField);
-    return flatVectorWriter.addField(fieldInfo, newField);
+    return newField;
   }

   @Override
@@ -171,8 +174,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
   @Override
   public long ramBytesUsed() {
     long total = SHALLOW_RAM_BYTES_USED;
+    // The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
     total += flatVectorWriter.ramBytesUsed();
     for (FieldWriter<?> field : fields) {
+      // the field tracks the delegate field usage
       total += field.ramBytesUsed();
     }
     return total;
   }

@@ -187,17 +192,19 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
         fieldData.fieldInfo,
         vectorIndexOffset,
         vectorIndexLength,
-        fieldData.docsWithField.cardinality(),
+        fieldData.getDocsWithFieldSet().cardinality(),
         graph,
         graphLevelNodeOffsets);
   }

   private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
       throws IOException {
-    final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
-    final int[] oldOrdMap = new int[fieldData.docsWithField.cardinality()]; // old ord to new ord
+    final int[] ordMap =
+        new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
+    final int[] oldOrdMap =
+        new int[fieldData.getDocsWithFieldSet().cardinality()]; // old ord to new ord

-    mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, oldOrdMap, ordMap, null);
+    mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, oldOrdMap, ordMap, null);
     // write graph
     long vectorIndexOffset = vectorIndex.getFilePointer();
     OnHeapHnswGraph graph = fieldData.getGraph();
@@ -209,7 +216,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
         fieldData.fieldInfo,
         vectorIndexOffset,
         vectorIndexLength,
-        fieldData.docsWithField.cardinality(),
+        fieldData.getDocsWithFieldSet().cardinality(),
         mockGraph,
         graphLevelNodeOffsets);
   }
@@ -521,42 +528,65 @@
         RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);

     private final FieldInfo fieldInfo;
-    private final DocsWithFieldSet docsWithField;
-    private final List<T> vectors;
     private final HnswGraphBuilder hnswGraphBuilder;
     private int lastDocID = -1;
     private int node = 0;
+    private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

     @SuppressWarnings("unchecked")
     static FieldWriter<?> create(
-        FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
+        FlatVectorsScorer scorer,
+        FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
+        FieldInfo fieldInfo,
+        int M,
+        int beamWidth,
+        InfoStream infoStream)
         throws IOException {
       return switch (fieldInfo.getVectorEncoding()) {
-        case BYTE -> new FieldWriter<byte[]>(scorer, fieldInfo, M, beamWidth, infoStream);
-        case FLOAT32 -> new FieldWriter<float[]>(scorer, fieldInfo, M, beamWidth, infoStream);
+        case BYTE -> new FieldWriter<>(
+            scorer,
+            (FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter,
+            fieldInfo,
+            M,
+            beamWidth,
+            infoStream);
+        case FLOAT32 -> new FieldWriter<>(
+            scorer,
+            (FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
+            fieldInfo,
+            M,
+            beamWidth,
+            infoStream);
       };
     }

     @SuppressWarnings("unchecked")
     FieldWriter(
-        FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
+        FlatVectorsScorer scorer,
+        FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
+        FieldInfo fieldInfo,
+        int M,
+        int beamWidth,
+        InfoStream infoStream)
         throws IOException {
       this.fieldInfo = fieldInfo;
-      this.docsWithField = new DocsWithFieldSet();
-      vectors = new ArrayList<>();
       RandomVectorScorerSupplier scorerSupplier =
           switch (fieldInfo.getVectorEncoding()) {
             case BYTE -> scorer.getRandomVectorScorerSupplier(
                 fieldInfo.getVectorSimilarityFunction(),
                 RandomAccessVectorValues.fromBytes(
-                    (List<byte[]>) vectors, fieldInfo.getVectorDimension()));
+                    (List<byte[]>) flatFieldVectorsWriter.getVectors(),
+                    fieldInfo.getVectorDimension()));
             case FLOAT32 -> scorer.getRandomVectorScorerSupplier(
                 fieldInfo.getVectorSimilarityFunction(),
                 RandomAccessVectorValues.fromFloats(
-                    (List<float[]>) vectors, fieldInfo.getVectorDimension()));
+                    (List<float[]>) flatFieldVectorsWriter.getVectors(),
+                    fieldInfo.getVectorDimension()));
           };
       hnswGraphBuilder =
           HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
       hnswGraphBuilder.setInfoStream(infoStream);
+      this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter);
     }

     @Override
@@ -567,20 +597,23 @@
               + fieldInfo.name
               + "\" appears more than once in this document (only one value is allowed per field)");
       }
       assert docID > lastDocID;
-      vectors.add(vectorValue);
-      docsWithField.add(docID);
+      flatFieldVectorsWriter.addValue(docID, vectorValue);
       hnswGraphBuilder.addGraphNode(node);
       node++;
       lastDocID = docID;
     }

+    public DocsWithFieldSet getDocsWithFieldSet() {
+      return flatFieldVectorsWriter.getDocsWithFieldSet();
+    }
+
     @Override
     public T copyValue(T vectorValue) {
       throw new UnsupportedOperationException();
     }

     OnHeapHnswGraph getGraph() {
+      assert flatFieldVectorsWriter.isFinished();
       if (node > 0) {
         return hnswGraphBuilder.getGraph();
       } else {
@@ -591,9 +624,7 @@
     @Override
     public long ramBytesUsed() {
       return SHALLOW_SIZE
-          + docsWithField.ramBytesUsed()
-          + (long) vectors.size()
-              * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+          + flatFieldVectorsWriter.ramBytesUsed()
           + hnswGraphBuilder.getGraph().ramBytesUsed();
     }
   }
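One subtlety in this file: HnswGraphBuilder scores each new node against vectors that are still being accumulated, and that keeps working because RandomAccessVectorValues.fromFloats/fromBytes appears to wrap the delegate's live, growing list rather than a snapshot. A rough model of that behavior (LiveListVectors is a hypothetical stand-in, not Lucene's RandomAccessVectorValues):

import java.util.ArrayList;
import java.util.List;

// Rough model: the wrapper reads the backing list by ordinal, so vectors added
// after the wrapper was created are still visible to the scorer.
class LiveListVectors {
  private final List<float[]> backing; // the flat writer's own buffer

  LiveListVectors(List<float[]> backing) {
    this.backing = backing;
  }

  float[] vectorValue(int ord) {
    return backing.get(ord); // sees every vector added so far
  }
}

class Demo {
  public static void main(String[] args) {
    List<float[]> flat = new ArrayList<>();
    LiveListVectors view = new LiveListVectors(flat);
    flat.add(new float[] {1f, 0f}); // added after the view was constructed
    System.out.println(view.vectorValue(0)[0]); // prints 1.0
  }
}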
lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java
@@ -30,8 +30,8 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Objects;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
@@ -56,7 +56,6 @@ import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.InfoStream;
-import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
@@ -195,8 +194,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   }

   @Override
-  public FlatFieldVectorsWriter<?> addField(
-      FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
     if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
       if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) {
         throw new IllegalArgumentException(
@@ -205,6 +204,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
                 + " is not supported for odd vector dimensions; vector dimension="
                 + fieldInfo.getVectorDimension());
       }
+      @SuppressWarnings("unchecked")
       FieldWriter quantizedWriter =
           new FieldWriter(
               confidenceInterval,
@@ -212,11 +212,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
               compress,
               fieldInfo,
               segmentWriteState.infoStream,
-              indexWriter);
+              (FlatFieldVectorsWriter<float[]>) rawVectorDelegate);
       fields.add(quantizedWriter);
-      indexWriter = quantizedWriter;
+      return quantizedWriter;
     }
-    return rawVectorDelegate.addField(fieldInfo, indexWriter);
+    return rawVectorDelegate;
   }

   @Override
@@ -270,12 +270,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
     rawVectorDelegate.flush(maxDoc, sortMap);
     for (FieldWriter field : fields) {
+      field.finish();
+      ScalarQuantizer quantizer = field.createQuantizer();
       if (sortMap == null) {
-        writeField(field, maxDoc);
+        writeField(field, maxDoc, quantizer);
       } else {
-        writeSortingField(field, maxDoc, sortMap);
+        writeSortingField(field, maxDoc, sortMap, quantizer);
       }
-      field.finish();
     }
   }

@@ -299,15 +300,18 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
   @Override
   public long ramBytesUsed() {
     long total = SHALLOW_RAM_BYTES_USED;
+    // The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
     total += rawVectorDelegate.ramBytesUsed();
     for (FieldWriter field : fields) {
+      // the field tracks the delegate field usage
       total += field.ramBytesUsed();
     }
     return total;
   }

-  private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
+  private void writeField(FieldWriter fieldData, int maxDoc, ScalarQuantizer scalarQuantizer)
+      throws IOException {
     // write vector values
     long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
-    writeQuantizedVectors(fieldData);
+    writeQuantizedVectors(fieldData, scalarQuantizer);
     long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;

     writeMeta(
@@ -318,9 +322,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         confidenceInterval,
         bits,
         compress,
-        fieldData.minQuantile,
-        fieldData.maxQuantile,
-        fieldData.docsWithField);
+        scalarQuantizer.getLowerQuantile(),
+        scalarQuantizer.getUpperQuantile(),
+        fieldData.getDocsWithFieldSet());
   }

   private void writeMeta(
@@ -365,8 +369,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
   }

-  private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
-    ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
+  private void writeQuantizedVectors(FieldWriter fieldData, ScalarQuantizer scalarQuantizer)
+      throws IOException {
     byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
     byte[] compressedVector =
         fieldData.compress
@@ -375,7 +379,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         : null;
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
     float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
-    for (float[] v : fieldData.floatVectors) {
+    assert fieldData.getVectors().isEmpty() || scalarQuantizer != null;
+    for (float[] v : fieldData.getVectors()) {
       if (fieldData.normalize) {
         System.arraycopy(v, 0, copy, 0, copy.length);
         VectorUtil.l2normalize(copy);
@@ -396,16 +401,18 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     }
   }

-  private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
+  private void writeSortingField(
+      FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap, ScalarQuantizer scalarQuantizer)
       throws IOException {
-    final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
+    final int[] ordMap =
+        new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord

     DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
-    mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField);
+    mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);

     // write vector values
     long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
-    writeSortedQuantizedVectors(fieldData, ordMap);
+    writeSortedQuantizedVectors(fieldData, ordMap, scalarQuantizer);
     long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
     writeMeta(
         fieldData.fieldInfo,
@@ -415,13 +422,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
         confidenceInterval,
         bits,
         compress,
-        fieldData.minQuantile,
-        fieldData.maxQuantile,
+        scalarQuantizer.getLowerQuantile(),
+        scalarQuantizer.getUpperQuantile(),
         newDocsWithField);
   }

-  private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
-    ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
+  private void writeSortedQuantizedVectors(
+      FieldWriter fieldData, int[] ordMap, ScalarQuantizer scalarQuantizer) throws IOException {
     byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
     byte[] compressedVector =
         fieldData.compress
@@ -431,7 +438,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
     float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
     for (int ordinal : ordMap) {
-      float[] v = fieldData.floatVectors.get(ordinal);
+      float[] v = fieldData.getVectors().get(ordinal);
       if (fieldData.normalize) {
         System.arraycopy(v, 0, copy, 0, copy.length);
         VectorUtil.l2normalize(copy);
@@ -744,44 +751,51 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite

   static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
     private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
-    private final List<float[]> floatVectors;
     private final FieldInfo fieldInfo;
     private final Float confidenceInterval;
     private final byte bits;
     private final boolean compress;
     private final InfoStream infoStream;
     private final boolean normalize;
-    private float minQuantile = Float.POSITIVE_INFINITY;
-    private float maxQuantile = Float.NEGATIVE_INFINITY;
     private boolean finished;
-    private final DocsWithFieldSet docsWithField;
+    private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;

-    @SuppressWarnings("unchecked")
     FieldWriter(
         Float confidenceInterval,
         byte bits,
         boolean compress,
         FieldInfo fieldInfo,
         InfoStream infoStream,
-        KnnFieldVectorsWriter<?> indexWriter) {
-      super((KnnFieldVectorsWriter<float[]>) indexWriter);
+        FlatFieldVectorsWriter<float[]> indexWriter) {
+      super();
       this.confidenceInterval = confidenceInterval;
       this.bits = bits;
       this.fieldInfo = fieldInfo;
       this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
-      this.floatVectors = new ArrayList<>();
       this.infoStream = infoStream;
-      this.docsWithField = new DocsWithFieldSet();
       this.compress = compress;
+      this.flatFieldVectorsWriter = Objects.requireNonNull(indexWriter);
     }

-    void finish() throws IOException {
+    @Override
+    public boolean isFinished() {
+      return finished && flatFieldVectorsWriter.isFinished();
+    }
+
+    @Override
+    public void finish() throws IOException {
       if (finished) {
         return;
       }
+      assert flatFieldVectorsWriter.isFinished();
+      finished = true;
+    }
+
+    ScalarQuantizer createQuantizer() throws IOException {
+      assert flatFieldVectorsWriter.isFinished();
+      List<float[]> floatVectors = flatFieldVectorsWriter.getVectors();
       if (floatVectors.size() == 0) {
-        finished = true;
-        return;
+        return new ScalarQuantizer(0, 0, bits);
       }
       FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
       ScalarQuantizer quantizer =
@@ -791,8 +805,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
               fieldInfo.getVectorSimilarityFunction(),
               confidenceInterval,
               bits);
-      minQuantile = quantizer.getLowerQuantile();
-      maxQuantile = quantizer.getUpperQuantile();
       if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
         infoStream.message(
             QUANTIZED_VECTOR_COMPONENT,
@@ -802,41 +814,39 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
                 + " bits="
                 + bits
                 + " minQuantile="
-                + minQuantile
+                + quantizer.getLowerQuantile()
                 + " maxQuantile="
-                + maxQuantile);
+                + quantizer.getUpperQuantile());
       }
-      finished = true;
-    }
-
-    ScalarQuantizer createQuantizer() {
-      assert finished;
-      return new ScalarQuantizer(minQuantile, maxQuantile, bits);
+      return quantizer;
     }

     @Override
     public long ramBytesUsed() {
       long size = SHALLOW_SIZE;
-      if (indexingDelegate != null) {
-        size += indexingDelegate.ramBytesUsed();
-      }
-      if (floatVectors.size() == 0) return size;
-      return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+      size += flatFieldVectorsWriter.ramBytesUsed();
+      return size;
     }

     @Override
     public void addValue(int docID, float[] vectorValue) throws IOException {
-      docsWithField.add(docID);
-      floatVectors.add(vectorValue);
-      if (indexingDelegate != null) {
-        indexingDelegate.addValue(docID, vectorValue);
-      }
+      flatFieldVectorsWriter.addValue(docID, vectorValue);
     }

     @Override
     public float[] copyValue(float[] vectorValue) {
       throw new UnsupportedOperationException();
     }
+
+    @Override
+    public List<float[]> getVectors() {
+      return flatFieldVectorsWriter.getVectors();
+    }
+
+    @Override
+    public DocsWithFieldSet getDocsWithFieldSet() {
+      return flatFieldVectorsWriter.getDocsWithFieldSet();
+    }
   }

   static class FloatVectorWrapper extends FloatVectorValues {
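With the cached quantiles gone from the field writer, the ScalarQuantizer is now created once per field at flush time and threaded through the write methods, so the quantiles travel with the quantizer instead of being cached as fields. A hedged sketch of that flow (QuantizedFlushSketch, computeQuantizer, writeVectors, and writeMeta are hypothetical stand-ins for the ScalarQuantizer factory and the private writeField/writeMeta helpers):

import java.util.List;

// Hypothetical sketch of the flush-time flow; not the actual Lucene code.
class QuantizedFlushSketch {
  record Quantizer(float lower, float upper) {}

  Quantizer computeQuantizer(List<float[]> vectors) {
    // stand-in: the real code derives quantiles from the vector distribution
    return vectors.isEmpty() ? new Quantizer(0f, 0f) : new Quantizer(-1f, 1f);
  }

  void flushField(List<float[]> buffered) {
    Quantizer q = computeQuantizer(buffered); // created once per field at flush
    writeVectors(buffered, q); // quantize every vector with the same q
    writeMeta(q.lower(), q.upper()); // persist the same quantiles in the metadata
  }

  void writeVectors(List<float[]> vectors, Quantizer q) {
    // stand-in for writeQuantizedVectors(fieldData, scalarQuantizer)
  }

  void writeMeta(float lower, float upper) {
    // stand-in for writeMeta(..., lowerQuantile, upperQuantile, docsWithField)
  }
}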