Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (#13763)

This commit is contained in:
panguixin 2024-10-31 23:16:09 +08:00 committed by Adrien Grand
parent cff28d546f
commit 584387a254
11 changed files with 218 additions and 205 deletions

View File

@ -56,6 +56,8 @@ Optimizations
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand) * GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)
Bug Fixes Bug Fixes
--------------------- ---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended * GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -20,8 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom; import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
@ -33,6 +31,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
@ -50,14 +49,16 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
*/ */
public final class Lucene90HnswVectorsReader extends KnnVectorsReader { public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final long checksumSeed; private final long checksumSeed;
private final FieldInfos fieldInfos;
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException { Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
long[] checksumRef = new long[1]; long[] checksumRef = new long[1];
this.fieldInfos = state.fieldInfos;
boolean success = false; boolean success = false;
try { try {
vectorData = vectorData =
@ -158,7 +159,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -218,13 +219,18 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override private FieldEntry getFieldEntry(String field) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
return getOffHeapVectorValues(fieldEntry); return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
} }
@Override @Override
@ -235,8 +241,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) { if (fieldEntry.size() == 0) {
return; return;
} }

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator; import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
@ -35,6 +33,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.VectorScorer;
@ -55,13 +54,15 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
*/ */
public final class Lucene91HnswVectorsReader extends KnnVectorsReader { public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException { Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false; boolean success = false;
try { try {
vectorData = vectorData =
@ -154,7 +155,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
} }
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -214,13 +215,18 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override private FieldEntry getFieldEntry(String field) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
return getOffHeapVectorValues(fieldEntry); return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
} }
@Override @Override
@ -231,8 +237,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) { if (fieldEntry.size() == 0) {
return; return;
} }

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -34,6 +32,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataInput;
@ -53,13 +52,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/ */
public final class Lucene92HnswVectorsReader extends KnnVectorsReader { public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException { Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false; boolean success = false;
try { try {
vectorData = vectorData =
@ -152,7 +153,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
} }
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -212,13 +213,18 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override private FieldEntry getFieldEntry(String field) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
return OffHeapFloatVectorValues.load(fieldEntry, vectorData); return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
} }
@Override @Override
@ -229,8 +235,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) { if (fieldEntry.size() == 0) {
return; return;
} }

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -35,6 +33,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataInput;
@ -54,13 +53,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/ */
public final class Lucene94HnswVectorsReader extends KnnVectorsReader { public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException { Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false; boolean success = false;
try { try {
vectorData = vectorData =
@ -153,7 +154,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
} }
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -230,48 +231,41 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"field=\"" "field=\""
+ field + field
+ "\" is encoded as: " + "\" is encoded as: "
+ fieldEntry.vectorEncoding + fieldEntry.vectorEncoding
+ " expected: " + " expected: "
+ VectorEncoding.FLOAT32); + expectedEncoding);
} }
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(fieldEntry, vectorData); return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
} }
@Override @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException { public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData); return OffHeapByteVectorValues.load(fieldEntry, vectorData);
} }
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return; return;
} }
@ -289,9 +283,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return; return;
} }

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
@ -39,6 +37,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataInput;
@ -61,7 +60,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider { public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
private final FieldInfos fieldInfos; private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
@ -161,7 +160,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
} }
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -238,21 +237,27 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
CodecUtil.checksumEntireFile(vectorIndex); CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"field=\"" "field=\""
+ field + field
+ "\" is encoded as: " + "\" is encoded as: "
+ fieldEntry.vectorEncoding + fieldEntry.vectorEncoding
+ " expected: " + " expected: "
+ VectorEncoding.FLOAT32); + expectedEncoding);
} }
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load( return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
defaultFlatVectorScorer, defaultFlatVectorScorer,
@ -266,19 +271,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException { public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load( return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
defaultFlatVectorScorer, defaultFlatVectorScorer,
@ -293,11 +286,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return; return;
} }
@ -324,11 +314,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override @Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return; return;
} }
@ -355,12 +342,12 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
/** Get knn graph values; used for testing */ /** Get knn graph values; used for testing */
@Override @Override
public HnswGraph getGraph(String field) throws IOException { public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field); final FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) { final FieldEntry entry;
throw new IllegalArgumentException("No such field '" + field + "'"); if (info == null || (entry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
FieldEntry entry = fields.get(field); if (entry.vectorIndexLength > 0) {
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry); return getGraph(entry);
} else { } else {
return HnswGraph.EMPTY; return HnswGraph.EMPTY;

View File

@ -26,8 +26,6 @@ import static org.apache.lucene.codecs.simpletext.SimpleTextKnnVectorsWriter.VEC
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
@ -36,6 +34,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.VectorScorer;
@ -63,7 +62,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
private final SegmentReadState readState; private final SegmentReadState readState;
private final IndexInput dataIn; private final IndexInput dataIn;
private final BytesRefBuilder scratch = new BytesRefBuilder(); private final BytesRefBuilder scratch = new BytesRefBuilder();
private final Map<String, FieldEntry> fieldEntries = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fieldEntries = new IntObjectHashMap<>();
SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException { SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException {
this.readState = readState; this.readState = readState;
@ -91,9 +90,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
docIds[i] = readInt(in, EMPTY); docIds[i] = readInt(in, EMPTY);
} }
assert fieldEntries.containsKey(fieldName) == false; assert fieldEntries.containsKey(fieldNumber) == false;
fieldEntries.put( fieldEntries.put(
fieldName, fieldNumber,
new FieldEntry( new FieldEntry(
dimension, dimension,
vectorDataOffset, vectorDataOffset,
@ -126,7 +125,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
throw new IllegalStateException( throw new IllegalStateException(
"KNN vectors readers should not be called on fields that don't enable KNN vectors"); "KNN vectors readers should not be called on fields that don't enable KNN vectors");
} }
FieldEntry fieldEntry = fieldEntries.get(field); FieldEntry fieldEntry = fieldEntries.get(info.number);
if (fieldEntry == null) { if (fieldEntry == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues // mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
@ -159,7 +158,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
throw new IllegalStateException( throw new IllegalStateException(
"KNN vectors readers should not be called on fields that don't enable KNN vectors"); "KNN vectors readers should not be called on fields that don't enable KNN vectors");
} }
FieldEntry fieldEntry = fieldEntries.get(field); FieldEntry fieldEntry = fieldEntries.get(info.number);
if (fieldEntry == null) { if (fieldEntry == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues // mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@ -38,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
@ -56,13 +55,15 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private static final long SHALLOW_SIZE = private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class); RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final FieldInfos fieldInfos;
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
throws IOException { throws IOException {
super(scorer); super(scorer);
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false; boolean success = false;
try { try {
vectorData = vectorData =
@ -155,15 +156,13 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = FieldEntry.create(meta, info); FieldEntry fieldEntry = FieldEntry.create(meta, info);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
return Lucene99FlatVectorsReader.SHALLOW_SIZE return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
+ RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
} }
@Override @Override
@ -171,21 +170,27 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
CodecUtil.checksumEntireFile(vectorData); CodecUtil.checksumEntireFile(vectorData);
} }
@Override private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"field=\"" "field=\""
+ field + field
+ "\" is encoded as: " + "\" is encoded as: "
+ fieldEntry.vectorEncoding + fieldEntry.vectorEncoding
+ " expected: " + " expected: "
+ VectorEncoding.FLOAT32); + expectedEncoding);
} }
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load( return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
vectorScorer, vectorScorer,
@ -199,19 +204,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException { public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load( return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
vectorScorer, vectorScorer,
@ -225,10 +218,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override @Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return vectorScorer.getRandomVectorScorer( return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
OffHeapFloatVectorValues.load( OffHeapFloatVectorValues.load(
@ -245,10 +235,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override @Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return null;
}
return vectorScorer.getRandomVectorScorer( return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
OffHeapByteVectorValues.load( OffHeapByteVectorValues.load(

View File

@ -21,9 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
@ -37,6 +35,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataInput;
@ -70,7 +69,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
private final FlatVectorsReader flatVectorsReader; private final FlatVectorsReader flatVectorsReader;
private final FieldInfos fieldInfos; private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader) public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
@ -162,7 +161,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
} }
FieldEntry fieldEntry = readField(meta, info); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -225,8 +224,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
return Lucene99HnswVectorsReader.SHALLOW_SIZE return Lucene99HnswVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap( + fields.ramBytesUsed()
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
+ flatVectorsReader.ramBytesUsed(); + flatVectorsReader.ramBytesUsed();
} }
@ -246,25 +244,43 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
return flatVectorsReader.getByteVectorValues(field); return flatVectorsReader.getByteVectorValues(field);
} }
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ expectedEncoding);
}
return fieldEntry;
}
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
search( search(
fields.get(field), fieldEntry,
knnCollector, knnCollector,
acceptDocs, acceptDocs,
VectorEncoding.FLOAT32,
() -> flatVectorsReader.getRandomVectorScorer(field, target)); () -> flatVectorsReader.getRandomVectorScorer(field, target));
} }
@Override @Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
search( search(
fields.get(field), fieldEntry,
knnCollector, knnCollector,
acceptDocs, acceptDocs,
VectorEncoding.BYTE,
() -> flatVectorsReader.getRandomVectorScorer(field, target)); () -> flatVectorsReader.getRandomVectorScorer(field, target));
} }
@ -272,13 +288,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
FieldEntry fieldEntry, FieldEntry fieldEntry,
KnnCollector knnCollector, KnnCollector knnCollector,
Bits acceptDocs, Bits acceptDocs,
VectorEncoding vectorEncoding,
IOSupplier<RandomVectorScorer> scorerSupplier) IOSupplier<RandomVectorScorer> scorerSupplier)
throws IOException { throws IOException {
if (fieldEntry.size() == 0 if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != vectorEncoding) {
return; return;
} }
final RandomVectorScorer scorer = scorerSupplier.get(); final RandomVectorScorer scorer = scorerSupplier.get();
@ -304,12 +317,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
@Override @Override
public HnswGraph getGraph(String field) throws IOException { public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field); final FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) { final FieldEntry entry;
throw new IllegalArgumentException("No such field '" + field + "'"); if (info == null || (entry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
FieldEntry entry = fields.get(field); if (entry.vectorIndexLength > 0) {
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry); return getGraph(entry);
} else { } else {
return HnswGraph.EMPTY; return HnswGraph.EMPTY;

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
@ -36,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IOContext;
@ -59,15 +58,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private static final long SHALLOW_SIZE = private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class); RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
private final Map<String, FieldEntry> fields = new HashMap<>(); private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput quantizedVectorData; private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader; private final FlatVectorsReader rawVectorsReader;
private final FieldInfos fieldInfos;
public Lucene99ScalarQuantizedVectorsReader( public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer) SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
throws IOException { throws IOException {
super(scorer); super(scorer);
this.rawVectorsReader = rawVectorsReader; this.rawVectorsReader = rawVectorsReader;
this.fieldInfos = state.fieldInfos;
int versionMeta = -1; int versionMeta = -1;
String metaFileName = String metaFileName =
IndexFileNames.segmentFileName( IndexFileNames.segmentFileName(
@ -118,7 +119,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
} }
FieldEntry fieldEntry = readField(meta, versionMeta, info); FieldEntry fieldEntry = readField(meta, versionMeta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.number, fieldEntry);
} }
} }
@ -163,10 +164,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
CodecUtil.checksumEntireFile(quantizedVectorData); CodecUtil.checksumEntireFile(quantizedVectorData);
} }
@Override private FieldEntry getFieldEntry(String field) {
public FloatVectorValues getFloatVectorValues(String field) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field);
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry;
if (fieldEntry == null) { if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found"); throw new IllegalArgumentException("field=\"" + field + "\" not found");
} }
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
@ -178,6 +179,12 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
+ " expected: " + " expected: "
+ VectorEncoding.FLOAT32); + VectorEncoding.FLOAT32);
} }
return fieldEntry;
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field); final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
OffHeapQuantizedByteVectorValues quantizedByteVectorValues = OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
OffHeapQuantizedByteVectorValues.load( OffHeapQuantizedByteVectorValues.load(
@ -241,10 +248,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override @Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
if (fieldEntry.scalarQuantizer == null) { if (fieldEntry.scalarQuantizer == null) {
return rawVectorsReader.getRandomVectorScorer(field, target); return rawVectorsReader.getRandomVectorScorer(field, target);
} }
@ -275,12 +279,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long size = SHALLOW_SIZE; return SHALLOW_SIZE + fields.ramBytesUsed() + rawVectorsReader.ramBytesUsed();
size +=
RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
size += rawVectorsReader.ramBytesUsed();
return size;
} }
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info) private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
@ -301,11 +300,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
} }
@Override @Override
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException { public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(fieldName); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return OffHeapQuantizedByteVectorValues.load( return OffHeapQuantizedByteVectorValues.load(
fieldEntry.ordToDoc, fieldEntry.ordToDoc,
fieldEntry.dimension, fieldEntry.dimension,
@ -320,11 +316,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
} }
@Override @Override
public ScalarQuantizer getQuantizationState(String fieldName) { public ScalarQuantizer getQuantizationState(String field) {
FieldEntry fieldEntry = fields.get(fieldName); final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return fieldEntry.scalarQuantizer; return fieldEntry.scalarQuantizer;
} }

View File

@ -19,7 +19,9 @@ package org.apache.lucene.codecs.perfield;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.ServiceLoader; import java.util.ServiceLoader;
import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter;
@ -28,11 +30,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.internal.hppc.ObjectCursor;
import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
@ -186,7 +191,8 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
/** VectorReader that can wrap multiple delegate readers, selected by field. */ /** VectorReader that can wrap multiple delegate readers, selected by field. */
public static class FieldsReader extends KnnVectorsReader { public static class FieldsReader extends KnnVectorsReader {
private final Map<String, KnnVectorsReader> fields = new HashMap<>(); private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
private final FieldInfos fieldInfos;
/** /**
* Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat * Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
@ -196,7 +202,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
* @throws IOException if one of the delegate readers throws * @throws IOException if one of the delegate readers throws
*/ */
public FieldsReader(final SegmentReadState readState) throws IOException { public FieldsReader(final SegmentReadState readState) throws IOException {
this.fieldInfos = readState.fieldInfos;
// Init each unique format: // Init each unique format:
boolean success = false; boolean success = false;
Map<String, KnnVectorsReader> formats = new HashMap<>(); Map<String, KnnVectorsReader> formats = new HashMap<>();
@ -221,7 +227,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
segmentSuffix, segmentSuffix,
format.fieldsReader(new SegmentReadState(readState, segmentSuffix))); format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
} }
fields.put(fieldName, formats.get(segmentSuffix)); fields.put(fi.number, formats.get(segmentSuffix));
} }
} }
} }
@ -239,51 +245,69 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
* @param field the name of a numeric vector field * @param field the name of a numeric vector field
*/ */
public KnnVectorsReader getFieldReader(String field) { public KnnVectorsReader getFieldReader(String field) {
return fields.get(field); final FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
return null;
}
return fields.get(info.number);
} }
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
for (KnnVectorsReader reader : fields.values()) { for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
reader.checkIntegrity(); cursor.value.checkIntegrity();
} }
} }
@Override @Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException { public FloatVectorValues getFloatVectorValues(String field) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field); final FieldInfo info = fieldInfos.fieldInfo(field);
if (knnVectorsReader == null) { final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return null; return null;
} else {
return knnVectorsReader.getFloatVectorValues(field);
} }
return reader.getFloatVectorValues(field);
} }
@Override @Override
public ByteVectorValues getByteVectorValues(String field) throws IOException { public ByteVectorValues getByteVectorValues(String field) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field); final FieldInfo info = fieldInfos.fieldInfo(field);
if (knnVectorsReader == null) { final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return null; return null;
} else {
return knnVectorsReader.getByteVectorValues(field);
} }
return reader.getByteVectorValues(field);
} }
@Override @Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
fields.get(field).search(field, target, knnCollector, acceptDocs); final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return;
}
reader.search(field, target, knnCollector, acceptDocs);
} }
@Override @Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException { throws IOException {
fields.get(field).search(field, target, knnCollector, acceptDocs); final FieldInfo info = fieldInfos.fieldInfo(field);
final KnnVectorsReader reader;
if (info == null || (reader = fields.get(info.number)) == null) {
return;
}
reader.search(field, target, knnCollector, acceptDocs);
} }
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(fields.values()); List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
readers.add(cursor.value);
}
IOUtils.close(readers);
} }
} }