mirror of https://github.com/apache/lucene.git
Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (#13763)
This commit is contained in:
parent
26e0737e40
commit
494b16063e
|
@ -82,6 +82,8 @@ Optimizations
|
|||
|
||||
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
|
||||
|
||||
* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
|
||||
|
|
|
@ -20,8 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
|
@ -33,6 +31,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
|
@ -50,14 +49,16 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
|
|||
*/
|
||||
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final long checksumSeed;
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||
int versionMeta = readMetadata(state);
|
||||
long[] checksumRef = new long[1];
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
|
@ -158,7 +159,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -218,13 +219,18 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
return getOffHeapVectorValues(fieldEntry);
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return getOffHeapVectorValues(getFieldEntry(field));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -235,8 +241,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
if (fieldEntry.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.IntUnaryOperator;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
|
@ -35,6 +33,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
@ -55,13 +54,15 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
|||
*/
|
||||
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||
int versionMeta = readMetadata(state);
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
|
@ -154,7 +155,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -214,13 +215,18 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
return getOffHeapVectorValues(fieldEntry);
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return getOffHeapVectorValues(getFieldEntry(field));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -231,8 +237,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
if (fieldEntry.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
|
@ -34,6 +32,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
|
@ -53,13 +52,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
|||
*/
|
||||
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||
int versionMeta = readMetadata(state);
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
|
@ -152,7 +153,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -212,13 +213,18 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -229,8 +235,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
if (fieldEntry.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
|
@ -35,6 +33,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
|
@ -54,13 +53,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
|||
*/
|
||||
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||
int versionMeta = readMetadata(state);
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
|
@ -153,7 +154,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,48 +231,41 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
+ expectedEncoding);
|
||||
}
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.BYTE);
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -289,9 +283,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
|
@ -39,6 +37,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
|
@ -61,7 +60,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
|||
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
|
||||
|
||||
private final FieldInfos fieldInfos;
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
|
@ -161,7 +160,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -238,21 +237,27 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
+ expectedEncoding);
|
||||
}
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
return OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
|
@ -266,19 +271,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.BYTE);
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
return OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
|
@ -293,11 +286,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -324,11 +314,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -355,12 +342,12 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
/** Get knn graph values; used for testing */
|
||||
@Override
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry entry;
|
||||
if (info == null || (entry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
FieldEntry entry = fields.get(field);
|
||||
if (entry != null && entry.vectorIndexLength > 0) {
|
||||
if (entry.vectorIndexLength > 0) {
|
||||
return getGraph(entry);
|
||||
} else {
|
||||
return HnswGraph.EMPTY;
|
||||
|
|
|
@ -26,8 +26,6 @@ import static org.apache.lucene.codecs.simpletext.SimpleTextKnnVectorsWriter.VEC
|
|||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
|
@ -36,6 +34,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
@ -63,7 +62,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
private final SegmentReadState readState;
|
||||
private final IndexInput dataIn;
|
||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||
private final Map<String, FieldEntry> fieldEntries = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fieldEntries = new IntObjectHashMap<>();
|
||||
|
||||
SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException {
|
||||
this.readState = readState;
|
||||
|
@ -91,9 +90,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
for (int i = 0; i < size; i++) {
|
||||
docIds[i] = readInt(in, EMPTY);
|
||||
}
|
||||
assert fieldEntries.containsKey(fieldName) == false;
|
||||
assert fieldEntries.containsKey(fieldNumber) == false;
|
||||
fieldEntries.put(
|
||||
fieldName,
|
||||
fieldNumber,
|
||||
new FieldEntry(
|
||||
dimension,
|
||||
vectorDataOffset,
|
||||
|
@ -126,7 +125,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
throw new IllegalStateException(
|
||||
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||
}
|
||||
FieldEntry fieldEntry = fieldEntries.get(field);
|
||||
FieldEntry fieldEntry = fieldEntries.get(info.number);
|
||||
if (fieldEntry == null) {
|
||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||
|
@ -159,7 +158,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
throw new IllegalStateException(
|
||||
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||
}
|
||||
FieldEntry fieldEntry = fieldEntries.get(field);
|
||||
FieldEntry fieldEntry = fieldEntries.get(info.number);
|
||||
if (fieldEntry == null) {
|
||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
|
|||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
|
@ -38,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -56,13 +55,15 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
private static final long SHALLOW_SIZE =
|
||||
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
|
||||
throws IOException {
|
||||
super(scorer);
|
||||
int versionMeta = readMetadata(state);
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
|
@ -155,15 +156,13 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||
}
|
||||
FieldEntry fieldEntry = FieldEntry.create(meta, info);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return Lucene99FlatVectorsReader.SHALLOW_SIZE
|
||||
+ RamUsageEstimator.sizeOfMap(
|
||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
||||
return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -171,21 +170,27 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
CodecUtil.checksumEntireFile(vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
+ expectedEncoding);
|
||||
}
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
return OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
|
@ -199,19 +204,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.BYTE);
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
return OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
|
@ -225,10 +218,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
return vectorScorer.getRandomVectorScorer(
|
||||
fieldEntry.similarityFunction,
|
||||
OffHeapFloatVectorValues.load(
|
||||
|
@ -245,10 +235,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
return null;
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
return vectorScorer.getRandomVectorScorer(
|
||||
fieldEntry.similarityFunction,
|
||||
OffHeapByteVectorValues.load(
|
||||
|
|
|
@ -21,9 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
|
@ -37,6 +35,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
|
@ -70,7 +69,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
|
||||
private final FlatVectorsReader flatVectorsReader;
|
||||
private final FieldInfos fieldInfos;
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput vectorIndex;
|
||||
|
||||
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
|
||||
|
@ -162,7 +161,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,8 +224,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
||||
+ RamUsageEstimator.sizeOfMap(
|
||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
|
||||
+ fields.ramBytesUsed()
|
||||
+ flatVectorsReader.ramBytesUsed();
|
||||
}
|
||||
|
||||
|
@ -246,25 +244,43 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
return flatVectorsReader.getByteVectorValues(field);
|
||||
}
|
||||
|
||||
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ expectedEncoding);
|
||||
}
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||
search(
|
||||
fields.get(field),
|
||||
fieldEntry,
|
||||
knnCollector,
|
||||
acceptDocs,
|
||||
VectorEncoding.FLOAT32,
|
||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||
search(
|
||||
fields.get(field),
|
||||
fieldEntry,
|
||||
knnCollector,
|
||||
acceptDocs,
|
||||
VectorEncoding.BYTE,
|
||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||
}
|
||||
|
||||
|
@ -272,13 +288,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
FieldEntry fieldEntry,
|
||||
KnnCollector knnCollector,
|
||||
Bits acceptDocs,
|
||||
VectorEncoding vectorEncoding,
|
||||
IOSupplier<RandomVectorScorer> scorerSupplier)
|
||||
throws IOException {
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != vectorEncoding) {
|
||||
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||
return;
|
||||
}
|
||||
final RandomVectorScorer scorer = scorerSupplier.get();
|
||||
|
@ -304,12 +317,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
|||
|
||||
@Override
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry entry;
|
||||
if (info == null || (entry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
FieldEntry entry = fields.get(field);
|
||||
if (entry != null && entry.vectorIndexLength > 0) {
|
||||
if (entry.vectorIndexLength > 0) {
|
||||
return getGraph(entry);
|
||||
} else {
|
||||
return HnswGraph.EMPTY;
|
||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
|
|||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
|
@ -36,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -59,15 +58,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
private static final long SHALLOW_SIZE =
|
||||
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
|
||||
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||
private final IndexInput quantizedVectorData;
|
||||
private final FlatVectorsReader rawVectorsReader;
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
public Lucene99ScalarQuantizedVectorsReader(
|
||||
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
|
||||
throws IOException {
|
||||
super(scorer);
|
||||
this.rawVectorsReader = rawVectorsReader;
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
int versionMeta = -1;
|
||||
String metaFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
|
@ -118,7 +119,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
}
|
||||
FieldEntry fieldEntry = readField(meta, versionMeta, info);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
fields.put(info.number, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,10 +164,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
CodecUtil.checksumEntireFile(quantizedVectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null) {
|
||||
private FieldEntry getFieldEntry(String field) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final FieldEntry fieldEntry;
|
||||
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||
}
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
|
@ -178,6 +179,12 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return fieldEntry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
||||
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
||||
OffHeapQuantizedByteVectorValues.load(
|
||||
|
@ -241,10 +248,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
if (fieldEntry.scalarQuantizer == null) {
|
||||
return rawVectorsReader.getRandomVectorScorer(field, target);
|
||||
}
|
||||
|
@ -275,12 +279,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size +=
|
||||
RamUsageEstimator.sizeOfMap(
|
||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
||||
size += rawVectorsReader.ramBytesUsed();
|
||||
return size;
|
||||
return SHALLOW_SIZE + fields.ramBytesUsed() + rawVectorsReader.ramBytesUsed();
|
||||
}
|
||||
|
||||
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
|
||||
|
@ -301,11 +300,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
}
|
||||
|
||||
@Override
|
||||
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(fieldName);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
return OffHeapQuantizedByteVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.dimension,
|
||||
|
@ -320,11 +316,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
}
|
||||
|
||||
@Override
|
||||
public ScalarQuantizer getQuantizationState(String fieldName) {
|
||||
FieldEntry fieldEntry = fields.get(fieldName);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
public ScalarQuantizer getQuantizationState(String field) {
|
||||
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||
return fieldEntry.scalarQuantizer;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,9 @@ package org.apache.lucene.codecs.perfield;
|
|||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.ServiceLoader;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
|
@ -28,11 +30,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.internal.hppc.ObjectCursor;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
|
@ -186,7 +191,8 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
/** VectorReader that can wrap multiple delegate readers, selected by field. */
|
||||
public static class FieldsReader extends KnnVectorsReader {
|
||||
|
||||
private final Map<String, KnnVectorsReader> fields = new HashMap<>();
|
||||
private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
|
||||
private final FieldInfos fieldInfos;
|
||||
|
||||
/**
|
||||
* Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
|
||||
|
@ -196,7 +202,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
* @throws IOException if one of the delegate readers throws
|
||||
*/
|
||||
public FieldsReader(final SegmentReadState readState) throws IOException {
|
||||
|
||||
this.fieldInfos = readState.fieldInfos;
|
||||
// Init each unique format:
|
||||
boolean success = false;
|
||||
Map<String, KnnVectorsReader> formats = new HashMap<>();
|
||||
|
@ -221,7 +227,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
segmentSuffix,
|
||||
format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
|
||||
}
|
||||
fields.put(fieldName, formats.get(segmentSuffix));
|
||||
fields.put(fi.number, formats.get(segmentSuffix));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -239,51 +245,69 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
* @param field the name of a numeric vector field
|
||||
*/
|
||||
public KnnVectorsReader getFieldReader(String field) {
|
||||
return fields.get(field);
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
return null;
|
||||
}
|
||||
return fields.get(info.number);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
for (KnnVectorsReader reader : fields.values()) {
|
||||
reader.checkIntegrity();
|
||||
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
|
||||
cursor.value.checkIntegrity();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final KnnVectorsReader reader;
|
||||
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||
return null;
|
||||
} else {
|
||||
return knnVectorsReader.getFloatVectorValues(field);
|
||||
}
|
||||
return reader.getFloatVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final KnnVectorsReader reader;
|
||||
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||
return null;
|
||||
} else {
|
||||
return knnVectorsReader.getByteVectorValues(field);
|
||||
}
|
||||
return reader.getByteVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
fields.get(field).search(field, target, knnCollector, acceptDocs);
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final KnnVectorsReader reader;
|
||||
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||
return;
|
||||
}
|
||||
reader.search(field, target, knnCollector, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
fields.get(field).search(field, target, knnCollector, acceptDocs);
|
||||
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
final KnnVectorsReader reader;
|
||||
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||
return;
|
||||
}
|
||||
reader.search(field, target, knnCollector, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(fields.values());
|
||||
List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
|
||||
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
|
||||
readers.add(cursor.value);
|
||||
}
|
||||
IOUtils.close(readers);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue