Fix quantized vector writer ram estimates (#13553)

* Fix quantized vector writer ram estimates

* add test & changes
This commit is contained in:
Benjamin Trent 2024-07-09 13:00:10 -04:00 committed by GitHub
parent 295c5d3576
commit 9bfde5514c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 82 additions and 6 deletions

View File

@ -297,6 +297,9 @@ Bug Fixes
* GITHUB#13463: Address bug in MultiLeafKnnCollector causing #minCompetitiveSimilarity to stay artificially low in
some corner cases. (Greg Miller)
* GITHUB#13553: Correct RamUsageEstimate for scalar quantized knn vector formats so that raw vectors are correctly
accounted for. (Ben Trent)
Other
--------------------
(No changes)

View File

@ -171,10 +171,8 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
// The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
total += flatVectorWriter.ramBytesUsed();
for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed();
}
return total;
}

View File

@ -299,9 +299,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
for (FieldWriter field : fields) {
total += field.ramBytesUsed();
}
// The vector delegate will also account for this writer's KnnFieldVectorsWriter objects
total += rawVectorDelegate.ramBytesUsed();
return total;
}

View File

@ -23,10 +23,15 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
@ -38,13 +43,19 @@ import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CheckIndex;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
@ -60,7 +71,10 @@ import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.Version;
import org.junit.Before;
/**
@ -216,6 +230,68 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
@SuppressWarnings("unchecked")
public void testWriterRamEstimate() throws Exception {
final FieldInfos fieldInfos = new FieldInfos(new FieldInfo[0]);
final Directory dir = newDirectory();
Codec codec = Codec.getDefault();
final SegmentInfo si =
new SegmentInfo(
dir,
Version.LATEST,
Version.LATEST,
"0",
10000,
false,
false,
codec,
Collections.emptyMap(),
StringHelper.randomId(),
new HashMap<>(),
null);
final SegmentWriteState state =
new SegmentWriteState(
InfoStream.getDefault(), dir, si, fieldInfos, null, newIOContext(random()));
final KnnVectorsFormat format = codec.knnVectorsFormat();
try (KnnVectorsWriter writer = format.fieldsWriter(state)) {
final long ramBytesUsed = writer.ramBytesUsed();
int dim = random().nextInt(64) + 1;
if (dim % 2 == 1) {
++dim;
}
int numDocs = atLeast(100);
KnnFieldVectorsWriter<float[]> fieldWriter =
(KnnFieldVectorsWriter<float[]>)
writer.addField(
new FieldInfo(
"fieldA",
0,
false,
false,
false,
IndexOptions.NONE,
DocValuesType.NONE,
false,
-1,
Map.of(),
0,
0,
0,
dim,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.DOT_PRODUCT,
false,
false));
for (int i = 0; i < numDocs; i++) {
fieldWriter.addValue(i, randomVector(dim));
}
final long ramBytesUsed2 = writer.ramBytesUsed();
assertTrue(ramBytesUsed2 > ramBytesUsed);
assertTrue(ramBytesUsed2 > (long) dim * numDocs * Float.BYTES);
}
dir.close();
}
public void testIllegalSimilarityFunctionChangeTwoWriters() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {