Use FieldInfo vector similarity in knn readers (#13237)

Both the KnnWriters & FieldInfo keep track of the vector similarity used by a given field. This commit ensures they are the same and utilizes the FieldInfo one (which, while these are enums, are exactly the same).
This commit is contained in:
Benjamin Trent 2024-03-29 09:10:39 -04:00 committed by GitHub
parent c41eb227ea
commit 69172b14ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 96 additions and 24 deletions

View File

@ -158,7 +158,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -200,9 +200,18 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId]; return VectorSimilarityFunction.values()[similarityFunctionId];
} }
private FieldEntry readField(DataInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -150,7 +150,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -192,9 +192,18 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId]; return VectorSimilarityFunction.values()[similarityFunctionId];
} }
private FieldEntry readField(DataInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -149,7 +149,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -191,9 +191,18 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId]; return VectorSimilarityFunction.values()[similarityFunctionId];
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -150,7 +150,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -208,10 +208,19 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return VectorEncoding.values()[encodingId]; return VectorEncoding.values()[encodingId];
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input); VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -162,7 +162,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -220,10 +220,19 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
return VectorEncoding.values()[encodingId]; return VectorEncoding.values()[encodingId];
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input); VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -143,7 +143,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -183,10 +183,19 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
} }
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input); VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -153,7 +153,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -200,10 +200,19 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
return VectorEncoding.values()[encodingId]; return VectorEncoding.values()[encodingId];
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input); VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
} }
@Override @Override

View File

@ -108,7 +108,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
FieldEntry fieldEntry = readField(meta); FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry); validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry); fields.put(info.name, fieldEntry);
} }
@ -236,10 +236,19 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
return size; return size;
} }
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input); VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction); if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction());
}
return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction());
} }
@Override @Override