Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (#13763)

This commit is contained in:
panguixin 2024-10-31 23:16:09 +08:00 committed by GitHub
parent 26e0737e40
commit 494b16063e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 218 additions and 205 deletions

View File

@ -82,6 +82,8 @@ Optimizations
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -20,8 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
@ -33,6 +31,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
@ -50,14 +49,16 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
*/
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final long checksumSeed;
private final FieldInfos fieldInfos;
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
long[] checksumRef = new long[1];
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
@ -158,7 +159,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -218,13 +219,18 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}
@Override
@ -235,8 +241,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
@ -35,6 +33,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
@ -55,13 +54,15 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
*/
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
@ -154,7 +155,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -214,13 +215,18 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}
@Override
@ -231,8 +237,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -34,6 +32,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
@ -53,13 +52,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
@ -152,7 +153,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -212,13 +213,18 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
}
@Override
@ -229,8 +235,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -35,6 +33,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
@ -54,13 +53,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
@ -153,7 +154,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -230,48 +231,41 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ expectedEncoding);
}
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}
@ -289,9 +283,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -39,6 +37,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
@ -61,7 +60,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
@ -161,7 +160,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -238,21 +237,27 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
CodecUtil.checksumEntireFile(vectorIndex);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ expectedEncoding);
}
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
@ -266,19 +271,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
@ -293,11 +286,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}
@ -324,11 +314,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}
@ -355,12 +342,12 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
/** Get knn graph values; used for testing */
@Override
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry entry;
if (info == null || (entry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
if (entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;

View File

@ -26,8 +26,6 @@ import static org.apache.lucene.codecs.simpletext.SimpleTextKnnVectorsWriter.VEC
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
@ -36,6 +34,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
@ -63,7 +62,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
private final SegmentReadState readState;
private final IndexInput dataIn;
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final Map<String, FieldEntry> fieldEntries = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fieldEntries = new IntObjectHashMap<>();
SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException {
this.readState = readState;
@ -91,9 +90,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
for (int i = 0; i < size; i++) {
docIds[i] = readInt(in, EMPTY);
}
assert fieldEntries.containsKey(fieldName) == false;
assert fieldEntries.containsKey(fieldNumber) == false;
fieldEntries.put(
fieldName,
fieldNumber,
new FieldEntry(
dimension,
vectorDataOffset,
@ -126,7 +125,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
throw new IllegalStateException(
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
}
FieldEntry fieldEntry = fieldEntries.get(field);
FieldEntry fieldEntry = fieldEntries.get(info.number);
if (fieldEntry == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
@ -159,7 +158,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
throw new IllegalStateException(
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
}
FieldEntry fieldEntry = fieldEntries.get(field);
FieldEntry fieldEntry = fieldEntries.get(info.number);
if (fieldEntry == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@ -38,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
@ -56,13 +55,15 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final FieldInfos fieldInfos;
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
@ -155,15 +156,13 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
FieldEntry fieldEntry = FieldEntry.create(meta, info);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@Override
public long ramBytesUsed() {
return Lucene99FlatVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
}
@Override
@ -171,21 +170,27 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
CodecUtil.checksumEntireFile(vectorData);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ expectedEncoding);
}
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
@ -199,19 +204,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
@ -225,10 +218,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapFloatVectorValues.load(
@ -245,10 +235,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return null;
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapByteVectorValues.load(

View File

@ -21,9 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
@ -37,6 +35,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
@ -70,7 +69,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
private final FlatVectorsReader flatVectorsReader;
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorIndex;
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
@ -162,7 +161,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -225,8 +224,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override
public long ramBytesUsed() {
return Lucene99HnswVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
+ fields.ramBytesUsed()
+ flatVectorsReader.ramBytesUsed();
}
@ -246,25 +244,43 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
return flatVectorsReader.getByteVectorValues(field);
}
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ expectedEncoding);
}
return fieldEntry;
}
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
search(
fields.get(field),
fieldEntry,
knnCollector,
acceptDocs,
VectorEncoding.FLOAT32,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
search(
fields.get(field),
fieldEntry,
knnCollector,
acceptDocs,
VectorEncoding.BYTE,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}
@ -272,13 +288,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
FieldEntry fieldEntry,
KnnCollector knnCollector,
Bits acceptDocs,
VectorEncoding vectorEncoding,
IOSupplier<RandomVectorScorer> scorerSupplier)
throws IOException {
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != vectorEncoding) {
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}
final RandomVectorScorer scorer = scorerSupplier.get();
@ -304,12 +317,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry entry;
if (info == null || (entry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
if (entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@ -36,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
@ -59,15 +58,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader;
private final FieldInfos fieldInfos;
public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
this.rawVectorsReader = rawVectorsReader;
this.fieldInfos = state.fieldInfos;
int versionMeta = -1;
String metaFileName =
IndexFileNames.segmentFileName(
@ -118,7 +119,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
}
FieldEntry fieldEntry = readField(meta, versionMeta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}
@ -163,10 +164,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
CodecUtil.checksumEntireFile(quantizedVectorData);
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
@ -178,6 +179,12 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
OffHeapQuantizedByteVectorValues.load(
@ -241,10 +248,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.scalarQuantizer == null) {
return rawVectorsReader.getRandomVectorScorer(field, target);
}
@ -275,12 +279,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size +=
RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
size += rawVectorsReader.ramBytesUsed();
return size;
return SHALLOW_SIZE + fields.ramBytesUsed() + rawVectorsReader.ramBytesUsed();
}
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
@ -301,11 +300,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
}
@Override
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
FieldEntry fieldEntry = fields.get(fieldName);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
return OffHeapQuantizedByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.dimension,
@ -320,11 +316,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
}
@Override
public ScalarQuantizer getQuantizationState(String fieldName) {
FieldEntry fieldEntry = fields.get(fieldName);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
public ScalarQuantizer getQuantizationState(String field) {
final FieldEntry fieldEntry = getFieldEntry(field);
return fieldEntry.scalarQuantizer;
}

View File

@ -19,7 +19,9 @@ package org.apache.lucene.codecs.perfield;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
@ -28,11 +30,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.internal.hppc.ObjectCursor;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
@ -186,7 +191,8 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
/** VectorReader that can wrap multiple delegate readers, selected by field. */
public static class FieldsReader extends KnnVectorsReader {
private final Map<String, KnnVectorsReader> fields = new HashMap<>();
private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
private final FieldInfos fieldInfos;
/**
* Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
@ -196,7 +202,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
* @throws IOException if one of the delegate readers throws
*/
public FieldsReader(final SegmentReadState readState) throws IOException {
this.fieldInfos = readState.fieldInfos;
// Init each unique format:
boolean success = false;
Map<String, KnnVectorsReader> formats = new HashMap<>();
@ -221,7 +227,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
segmentSuffix,
format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
}
fields.put(fieldName, formats.get(segmentSuffix));
fields.put(fi.number, formats.get(segmentSuffix));
}
}
}
@ -239,51 +245,69 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
* @param field the name of a numeric vector field
*/
public KnnVectorsReader getFieldReader(String field) {
return fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
return null;
}
return fields.get(info.number);
}
@Override
public void checkIntegrity() throws IOException {
for (KnnVectorsReader reader : fields.values()) {
reader.checkIntegrity();
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
cursor.value.checkIntegrity();
}
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return null;
} else {
return knnVectorsReader.getFloatVectorValues(field);
}
return reader.getFloatVectorValues(field);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return null;
} else {
return knnVectorsReader.getByteVectorValues(field);
}
return reader.getByteVectorValues(field);
}
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
fields.get(field).search(field, target, knnCollector, acceptDocs);
final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return;
}
reader.search(field, target, knnCollector, acceptDocs);
}
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
fields.get(field).search(field, target, knnCollector, acceptDocs);
final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return;
}
reader.search(field, target, knnCollector, acceptDocs);
}
@Override
public void close() throws IOException {
IOUtils.close(fields.values());
List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
readers.add(cursor.value);
}
IOUtils.close(readers);
}
}