mirror of https://github.com/apache/lucene.git
Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (#13763)
This commit is contained in:
parent
cff28d546f
commit
584387a254
|
@ -56,6 +56,8 @@ Optimizations
|
||||||
|
|
||||||
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
|
* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)
|
||||||
|
|
||||||
|
* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)
|
||||||
|
|
||||||
Bug Fixes
|
Bug Fixes
|
||||||
---------------------
|
---------------------
|
||||||
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
|
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
|
||||||
|
|
|
@ -20,8 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.SplittableRandom;
|
import java.util.SplittableRandom;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
@ -33,6 +31,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.search.VectorScorer;
|
import org.apache.lucene.search.VectorScorer;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
|
@ -50,14 +49,16 @@ import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
*/
|
*/
|
||||||
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final long checksumSeed;
|
private final long checksumSeed;
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
|
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||||
int versionMeta = readMetadata(state);
|
int versionMeta = readMetadata(state);
|
||||||
long[] checksumRef = new long[1];
|
long[] checksumRef = new long[1];
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
vectorData =
|
vectorData =
|
||||||
|
@ -158,7 +159,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,13 +219,18 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
return getOffHeapVectorValues(getFieldEntry(field));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -235,8 +241,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
|
|
||||||
if (fieldEntry.size() == 0) {
|
if (fieldEntry.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.function.IntUnaryOperator;
|
import java.util.function.IntUnaryOperator;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
@ -35,6 +33,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.search.VectorScorer;
|
import org.apache.lucene.search.VectorScorer;
|
||||||
|
@ -55,13 +54,15 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
*/
|
*/
|
||||||
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
|
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||||
int versionMeta = readMetadata(state);
|
int versionMeta = readMetadata(state);
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
vectorData =
|
vectorData =
|
||||||
|
@ -154,7 +155,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,13 +215,18 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
return getOffHeapVectorValues(getFieldEntry(field));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -231,8 +237,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
|
|
||||||
if (fieldEntry.size() == 0) {
|
if (fieldEntry.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||||
|
@ -34,6 +32,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataInput;
|
import org.apache.lucene.store.DataInput;
|
||||||
|
@ -53,13 +52,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
*/
|
*/
|
||||||
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
|
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||||
int versionMeta = readMetadata(state);
|
int versionMeta = readMetadata(state);
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
vectorData =
|
vectorData =
|
||||||
|
@ -152,7 +153,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,13 +213,18 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -229,8 +235,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
|
|
||||||
if (fieldEntry.size() == 0) {
|
if (fieldEntry.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||||
|
@ -35,6 +33,7 @@ import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataInput;
|
import org.apache.lucene.store.DataInput;
|
||||||
|
@ -54,13 +53,15 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
*/
|
*/
|
||||||
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
|
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||||
int versionMeta = readMetadata(state);
|
int versionMeta = readMetadata(state);
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
vectorData =
|
vectorData =
|
||||||
|
@ -153,7 +154,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,48 +231,41 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
+ field
|
+ field
|
||||||
+ "\" is encoded as: "
|
+ "\" is encoded as: "
|
||||||
+ fieldEntry.vectorEncoding
|
+ fieldEntry.vectorEncoding
|
||||||
+ " expected: "
|
+ " expected: "
|
||||||
+ VectorEncoding.FLOAT32);
|
+ expectedEncoding);
|
||||||
}
|
}
|
||||||
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
if (fieldEntry == null) {
|
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
|
||||||
}
|
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"field=\""
|
|
||||||
+ field
|
|
||||||
+ "\" is encoded as: "
|
|
||||||
+ fieldEntry.vectorEncoding
|
|
||||||
+ " expected: "
|
|
||||||
+ VectorEncoding.BYTE);
|
|
||||||
}
|
|
||||||
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
|
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||||
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,9 +283,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
|
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||||
if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||||
|
@ -39,6 +37,7 @@ import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataInput;
|
import org.apache.lucene.store.DataInput;
|
||||||
|
@ -61,7 +60,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
|
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
|
||||||
|
|
||||||
private final FieldInfos fieldInfos;
|
private final FieldInfos fieldInfos;
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||||
|
@ -161,7 +160,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,21 +237,27 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
+ field
|
+ field
|
||||||
+ "\" is encoded as: "
|
+ "\" is encoded as: "
|
||||||
+ fieldEntry.vectorEncoding
|
+ fieldEntry.vectorEncoding
|
||||||
+ " expected: "
|
+ " expected: "
|
||||||
+ VectorEncoding.FLOAT32);
|
+ expectedEncoding);
|
||||||
}
|
}
|
||||||
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
return OffHeapFloatVectorValues.load(
|
return OffHeapFloatVectorValues.load(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
defaultFlatVectorScorer,
|
defaultFlatVectorScorer,
|
||||||
|
@ -266,19 +271,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
if (fieldEntry == null) {
|
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
|
||||||
}
|
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"field=\""
|
|
||||||
+ field
|
|
||||||
+ "\" is encoded as: "
|
|
||||||
+ fieldEntry.vectorEncoding
|
|
||||||
+ " expected: "
|
|
||||||
+ VectorEncoding.BYTE);
|
|
||||||
}
|
|
||||||
return OffHeapByteVectorValues.load(
|
return OffHeapByteVectorValues.load(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
defaultFlatVectorScorer,
|
defaultFlatVectorScorer,
|
||||||
|
@ -293,11 +286,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
|
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||||
if (fieldEntry.size() == 0
|
|
||||||
|| knnCollector.k() == 0
|
|
||||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -324,11 +314,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
|
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||||
if (fieldEntry.size() == 0
|
|
||||||
|| knnCollector.k() == 0
|
|
||||||
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,12 +342,12 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
/** Get knn graph values; used for testing */
|
/** Get knn graph values; used for testing */
|
||||||
@Override
|
@Override
|
||||||
public HnswGraph getGraph(String field) throws IOException {
|
public HnswGraph getGraph(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (info == null) {
|
final FieldEntry entry;
|
||||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
if (info == null || (entry = fields.get(info.number)) == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
FieldEntry entry = fields.get(field);
|
if (entry.vectorIndexLength > 0) {
|
||||||
if (entry != null && entry.vectorIndexLength > 0) {
|
|
||||||
return getGraph(entry);
|
return getGraph(entry);
|
||||||
} else {
|
} else {
|
||||||
return HnswGraph.EMPTY;
|
return HnswGraph.EMPTY;
|
||||||
|
|
|
@ -26,8 +26,6 @@ import static org.apache.lucene.codecs.simpletext.SimpleTextKnnVectorsWriter.VEC
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.index.ByteVectorValues;
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
|
@ -36,6 +34,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.search.VectorScorer;
|
import org.apache.lucene.search.VectorScorer;
|
||||||
|
@ -63,7 +62,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
private final SegmentReadState readState;
|
private final SegmentReadState readState;
|
||||||
private final IndexInput dataIn;
|
private final IndexInput dataIn;
|
||||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||||
private final Map<String, FieldEntry> fieldEntries = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fieldEntries = new IntObjectHashMap<>();
|
||||||
|
|
||||||
SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException {
|
SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException {
|
||||||
this.readState = readState;
|
this.readState = readState;
|
||||||
|
@ -91,9 +90,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
docIds[i] = readInt(in, EMPTY);
|
docIds[i] = readInt(in, EMPTY);
|
||||||
}
|
}
|
||||||
assert fieldEntries.containsKey(fieldName) == false;
|
assert fieldEntries.containsKey(fieldNumber) == false;
|
||||||
fieldEntries.put(
|
fieldEntries.put(
|
||||||
fieldName,
|
fieldNumber,
|
||||||
new FieldEntry(
|
new FieldEntry(
|
||||||
dimension,
|
dimension,
|
||||||
vectorDataOffset,
|
vectorDataOffset,
|
||||||
|
@ -126,7 +125,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = fieldEntries.get(field);
|
FieldEntry fieldEntry = fieldEntries.get(info.number);
|
||||||
if (fieldEntry == null) {
|
if (fieldEntry == null) {
|
||||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||||
|
@ -159,7 +158,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = fieldEntries.get(field);
|
FieldEntry fieldEntry = fieldEntries.get(info.number);
|
||||||
if (fieldEntry == null) {
|
if (fieldEntry == null) {
|
||||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||||
|
@ -38,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.IOContext;
|
import org.apache.lucene.store.IOContext;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
@ -56,13 +55,15 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
private static final long SHALLOW_SIZE =
|
private static final long SHALLOW_SIZE =
|
||||||
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorData;
|
private final IndexInput vectorData;
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
|
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
super(scorer);
|
super(scorer);
|
||||||
int versionMeta = readMetadata(state);
|
int versionMeta = readMetadata(state);
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
vectorData =
|
vectorData =
|
||||||
|
@ -155,15 +156,13 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = FieldEntry.create(meta, info);
|
FieldEntry fieldEntry = FieldEntry.create(meta, info);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
return Lucene99FlatVectorsReader.SHALLOW_SIZE
|
return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
|
||||||
+ RamUsageEstimator.sizeOfMap(
|
|
||||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -171,21 +170,27 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
CodecUtil.checksumEntireFile(vectorData);
|
CodecUtil.checksumEntireFile(vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
+ field
|
+ field
|
||||||
+ "\" is encoded as: "
|
+ "\" is encoded as: "
|
||||||
+ fieldEntry.vectorEncoding
|
+ fieldEntry.vectorEncoding
|
||||||
+ " expected: "
|
+ " expected: "
|
||||||
+ VectorEncoding.FLOAT32);
|
+ expectedEncoding);
|
||||||
}
|
}
|
||||||
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
return OffHeapFloatVectorValues.load(
|
return OffHeapFloatVectorValues.load(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
vectorScorer,
|
vectorScorer,
|
||||||
|
@ -199,19 +204,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
if (fieldEntry == null) {
|
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
|
||||||
}
|
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"field=\""
|
|
||||||
+ field
|
|
||||||
+ "\" is encoded as: "
|
|
||||||
+ fieldEntry.vectorEncoding
|
|
||||||
+ " expected: "
|
|
||||||
+ VectorEncoding.BYTE);
|
|
||||||
}
|
|
||||||
return OffHeapByteVectorValues.load(
|
return OffHeapByteVectorValues.load(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
vectorScorer,
|
vectorScorer,
|
||||||
|
@ -225,10 +218,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return vectorScorer.getRandomVectorScorer(
|
return vectorScorer.getRandomVectorScorer(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
OffHeapFloatVectorValues.load(
|
OffHeapFloatVectorValues.load(
|
||||||
|
@ -245,10 +235,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
|
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return vectorScorer.getRandomVectorScorer(
|
return vectorScorer.getRandomVectorScorer(
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
OffHeapByteVectorValues.load(
|
OffHeapByteVectorValues.load(
|
||||||
|
|
|
@ -21,9 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
|
@ -37,6 +35,7 @@ import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataInput;
|
import org.apache.lucene.store.DataInput;
|
||||||
|
@ -70,7 +69,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|
|
||||||
private final FlatVectorsReader flatVectorsReader;
|
private final FlatVectorsReader flatVectorsReader;
|
||||||
private final FieldInfos fieldInfos;
|
private final FieldInfos fieldInfos;
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
|
|
||||||
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
|
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
|
||||||
|
@ -162,7 +161,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, info);
|
FieldEntry fieldEntry = readField(meta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,8 +224,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
||||||
+ RamUsageEstimator.sizeOfMap(
|
+ fields.ramBytesUsed()
|
||||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
|
|
||||||
+ flatVectorsReader.ramBytesUsed();
|
+ flatVectorsReader.ramBytesUsed();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,25 +244,43 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
return flatVectorsReader.getByteVectorValues(field);
|
return flatVectorsReader.getByteVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
|
||||||
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
|
final FieldEntry fieldEntry;
|
||||||
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
|
if (fieldEntry.vectorEncoding != expectedEncoding) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ expectedEncoding);
|
||||||
|
}
|
||||||
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
|
||||||
search(
|
search(
|
||||||
fields.get(field),
|
fieldEntry,
|
||||||
knnCollector,
|
knnCollector,
|
||||||
acceptDocs,
|
acceptDocs,
|
||||||
VectorEncoding.FLOAT32,
|
|
||||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
|
||||||
search(
|
search(
|
||||||
fields.get(field),
|
fieldEntry,
|
||||||
knnCollector,
|
knnCollector,
|
||||||
acceptDocs,
|
acceptDocs,
|
||||||
VectorEncoding.BYTE,
|
|
||||||
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
() -> flatVectorsReader.getRandomVectorScorer(field, target));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -272,13 +288,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
FieldEntry fieldEntry,
|
FieldEntry fieldEntry,
|
||||||
KnnCollector knnCollector,
|
KnnCollector knnCollector,
|
||||||
Bits acceptDocs,
|
Bits acceptDocs,
|
||||||
VectorEncoding vectorEncoding,
|
|
||||||
IOSupplier<RandomVectorScorer> scorerSupplier)
|
IOSupplier<RandomVectorScorer> scorerSupplier)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
if (fieldEntry.size() == 0
|
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
|
||||||
|| knnCollector.k() == 0
|
|
||||||
|| fieldEntry.vectorEncoding != vectorEncoding) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
final RandomVectorScorer scorer = scorerSupplier.get();
|
final RandomVectorScorer scorer = scorerSupplier.get();
|
||||||
|
@ -304,12 +317,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public HnswGraph getGraph(String field) throws IOException {
|
public HnswGraph getGraph(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (info == null) {
|
final FieldEntry entry;
|
||||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
if (info == null || (entry = fields.get(info.number)) == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
FieldEntry entry = fields.get(field);
|
if (entry.vectorIndexLength > 0) {
|
||||||
if (entry != null && entry.vectorIndexLength > 0) {
|
|
||||||
return getGraph(entry);
|
return getGraph(entry);
|
||||||
} else {
|
} else {
|
||||||
return HnswGraph.EMPTY;
|
return HnswGraph.EMPTY;
|
||||||
|
|
|
@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSi
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||||
|
@ -36,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
import org.apache.lucene.search.VectorScorer;
|
import org.apache.lucene.search.VectorScorer;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.IOContext;
|
import org.apache.lucene.store.IOContext;
|
||||||
|
@ -59,15 +58,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
private static final long SHALLOW_SIZE =
|
private static final long SHALLOW_SIZE =
|
||||||
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
|
||||||
|
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
|
||||||
private final IndexInput quantizedVectorData;
|
private final IndexInput quantizedVectorData;
|
||||||
private final FlatVectorsReader rawVectorsReader;
|
private final FlatVectorsReader rawVectorsReader;
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
public Lucene99ScalarQuantizedVectorsReader(
|
public Lucene99ScalarQuantizedVectorsReader(
|
||||||
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
|
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
super(scorer);
|
super(scorer);
|
||||||
this.rawVectorsReader = rawVectorsReader;
|
this.rawVectorsReader = rawVectorsReader;
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
int versionMeta = -1;
|
int versionMeta = -1;
|
||||||
String metaFileName =
|
String metaFileName =
|
||||||
IndexFileNames.segmentFileName(
|
IndexFileNames.segmentFileName(
|
||||||
|
@ -118,7 +119,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
}
|
}
|
||||||
FieldEntry fieldEntry = readField(meta, versionMeta, info);
|
FieldEntry fieldEntry = readField(meta, versionMeta, info);
|
||||||
validateFieldEntry(info, fieldEntry);
|
validateFieldEntry(info, fieldEntry);
|
||||||
fields.put(info.name, fieldEntry);
|
fields.put(info.number, fieldEntry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,10 +164,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
CodecUtil.checksumEntireFile(quantizedVectorData);
|
CodecUtil.checksumEntireFile(quantizedVectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private FieldEntry getFieldEntry(String field) {
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry;
|
||||||
if (fieldEntry == null) {
|
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
|
||||||
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
}
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
@ -178,6 +179,12 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
+ " expected: "
|
+ " expected: "
|
||||||
+ VectorEncoding.FLOAT32);
|
+ VectorEncoding.FLOAT32);
|
||||||
}
|
}
|
||||||
|
return fieldEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
||||||
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
||||||
OffHeapQuantizedByteVectorValues.load(
|
OffHeapQuantizedByteVectorValues.load(
|
||||||
|
@ -241,10 +248,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
if (fieldEntry.scalarQuantizer == null) {
|
if (fieldEntry.scalarQuantizer == null) {
|
||||||
return rawVectorsReader.getRandomVectorScorer(field, target);
|
return rawVectorsReader.getRandomVectorScorer(field, target);
|
||||||
}
|
}
|
||||||
|
@ -275,12 +279,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
long size = SHALLOW_SIZE;
|
return SHALLOW_SIZE + fields.ramBytesUsed() + rawVectorsReader.ramBytesUsed();
|
||||||
size +=
|
|
||||||
RamUsageEstimator.sizeOfMap(
|
|
||||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
|
||||||
size += rawVectorsReader.ramBytesUsed();
|
|
||||||
return size;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
|
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
|
||||||
|
@ -301,11 +300,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
|
public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(fieldName);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return OffHeapQuantizedByteVectorValues.load(
|
return OffHeapQuantizedByteVectorValues.load(
|
||||||
fieldEntry.ordToDoc,
|
fieldEntry.ordToDoc,
|
||||||
fieldEntry.dimension,
|
fieldEntry.dimension,
|
||||||
|
@ -320,11 +316,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ScalarQuantizer getQuantizationState(String fieldName) {
|
public ScalarQuantizer getQuantizationState(String field) {
|
||||||
FieldEntry fieldEntry = fields.get(fieldName);
|
final FieldEntry fieldEntry = getFieldEntry(field);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return fieldEntry.scalarQuantizer;
|
return fieldEntry.scalarQuantizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,9 @@ package org.apache.lucene.codecs.perfield;
|
||||||
|
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.ServiceLoader;
|
import java.util.ServiceLoader;
|
||||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
|
@ -28,11 +30,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
import org.apache.lucene.index.ByteVectorValues;
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
|
import org.apache.lucene.internal.hppc.ObjectCursor;
|
||||||
import org.apache.lucene.search.KnnCollector;
|
import org.apache.lucene.search.KnnCollector;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
@ -186,7 +191,8 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
/** VectorReader that can wrap multiple delegate readers, selected by field. */
|
/** VectorReader that can wrap multiple delegate readers, selected by field. */
|
||||||
public static class FieldsReader extends KnnVectorsReader {
|
public static class FieldsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
private final Map<String, KnnVectorsReader> fields = new HashMap<>();
|
private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
|
* Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
|
||||||
|
@ -196,7 +202,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
* @throws IOException if one of the delegate readers throws
|
* @throws IOException if one of the delegate readers throws
|
||||||
*/
|
*/
|
||||||
public FieldsReader(final SegmentReadState readState) throws IOException {
|
public FieldsReader(final SegmentReadState readState) throws IOException {
|
||||||
|
this.fieldInfos = readState.fieldInfos;
|
||||||
// Init each unique format:
|
// Init each unique format:
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
Map<String, KnnVectorsReader> formats = new HashMap<>();
|
Map<String, KnnVectorsReader> formats = new HashMap<>();
|
||||||
|
@ -221,7 +227,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
segmentSuffix,
|
segmentSuffix,
|
||||||
format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
|
format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
|
||||||
}
|
}
|
||||||
fields.put(fieldName, formats.get(segmentSuffix));
|
fields.put(fi.number, formats.get(segmentSuffix));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -239,51 +245,69 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
* @param field the name of a numeric vector field
|
* @param field the name of a numeric vector field
|
||||||
*/
|
*/
|
||||||
public KnnVectorsReader getFieldReader(String field) {
|
public KnnVectorsReader getFieldReader(String field) {
|
||||||
return fields.get(field);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
|
if (info == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return fields.get(info.number);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
for (KnnVectorsReader reader : fields.values()) {
|
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
|
||||||
reader.checkIntegrity();
|
cursor.value.checkIntegrity();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (knnVectorsReader == null) {
|
final KnnVectorsReader reader;
|
||||||
|
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||||
return null;
|
return null;
|
||||||
} else {
|
|
||||||
return knnVectorsReader.getFloatVectorValues(field);
|
|
||||||
}
|
}
|
||||||
|
return reader.getFloatVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (knnVectorsReader == null) {
|
final KnnVectorsReader reader;
|
||||||
|
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||||
return null;
|
return null;
|
||||||
} else {
|
|
||||||
return knnVectorsReader.getByteVectorValues(field);
|
|
||||||
}
|
}
|
||||||
|
return reader.getByteVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
fields.get(field).search(field, target, knnCollector, acceptDocs);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
|
final KnnVectorsReader reader;
|
||||||
|
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
reader.search(field, target, knnCollector, acceptDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
fields.get(field).search(field, target, knnCollector, acceptDocs);
|
final FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
|
final KnnVectorsReader reader;
|
||||||
|
if (info == null || (reader = fields.get(info.number)) == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
reader.search(field, target, knnCollector, acceptDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
IOUtils.close(fields.values());
|
List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
|
||||||
|
for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
|
||||||
|
readers.add(cursor.value);
|
||||||
|
}
|
||||||
|
IOUtils.close(readers);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue