mirror of https://github.com/apache/lucene.git
Unify how missing field entries are handled in knn formats (#13641)
During segment merge we must verify that a given field exists and has vectors. The typical knn format checks assume the per-field format is used and thus only check for `null`. But we should check for field existence in the field info and verify that it has dense vectors. Additionally, this commit unifies how the knn formats work: they will throw if a non-existing field is queried, except for the PerField format, which will return null (like the other per-field formats).
This commit is contained in:
parent
8760654faa
commit
74e3c44063
|
@ -452,6 +452,9 @@ Bug Fixes
|
||||||
* GITHUB#13703: Fix bug in LatLonPoint queries where narrow polygons close to latitude 90 don't
|
* GITHUB#13703: Fix bug in LatLonPoint queries where narrow polygons close to latitude 90 don't
|
||||||
match any points due to an Integer overflow. (Ignacio Vera)
|
match any points due to an Integer overflow. (Ignacio Vera)
|
||||||
|
|
||||||
|
* GITHUB#13641: Unify how KnnFormats handle missing fields and correctly handle missing vector fields when
|
||||||
|
merging segments. (Ben Trent)
|
||||||
|
|
||||||
Build
|
Build
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
|
|
|
@ -224,6 +224,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return getOffHeapVectorValues(fieldEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return getOffHeapVectorValues(fieldEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -215,6 +215,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -233,6 +233,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
@ -248,6 +251,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
|
|
@ -241,6 +241,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
@ -264,6 +267,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
|
|
@ -78,4 +78,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
||||||
public void testEmptyByteVectorData() {
|
public void testEmptyByteVectorData() {
|
||||||
// unimplemented
|
// unimplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void testMergingWithDifferentByteKnnFields() {
|
||||||
|
// unimplemented
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,4 +77,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
||||||
public void testEmptyByteVectorData() {
|
public void testEmptyByteVectorData() {
|
||||||
// unimplemented
|
// unimplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void testMergingWithDifferentByteKnnFields() {
|
||||||
|
// unimplemented
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,4 +67,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
||||||
public void testEmptyByteVectorData() {
|
public void testEmptyByteVectorData() {
|
||||||
// unimplemented
|
// unimplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void testMergingWithDifferentByteKnnFields() {
|
||||||
|
// unimplemented
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.lucene.backward_codecs.lucene95;
|
package org.apache.lucene.backward_codecs.lucene95;
|
||||||
|
|
||||||
import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
|
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -476,9 +477,11 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
IncrementalHnswGraphMerger merger =
|
IncrementalHnswGraphMerger merger =
|
||||||
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
||||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||||
|
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
|
||||||
merger.addReader(
|
merger.addReader(
|
||||||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
DocIdSetIterator mergedVectorIterator = null;
|
DocIdSetIterator mergedVectorIterator = null;
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
case BYTE ->
|
case BYTE ->
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DocIDMerger;
|
import org.apache.lucene.index.DocIDMerger;
|
||||||
import org.apache.lucene.index.DocsWithFieldSet;
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
|
@ -212,14 +213,35 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns true if the fieldInfos has vector values for the field.
|
||||||
|
*
|
||||||
|
* @param fieldInfos fieldInfos for the segment
|
||||||
|
* @param fieldName field name
|
||||||
|
* @return true if the fieldInfos has vector values for the field.
|
||||||
|
*/
|
||||||
|
public static boolean hasVectorValues(FieldInfos fieldInfos, String fieldName) {
|
||||||
|
if (fieldInfos.hasVectorValues() == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
FieldInfo info = fieldInfos.fieldInfo(fieldName);
|
||||||
|
return info != null && info.hasVectorValues();
|
||||||
|
}
|
||||||
|
|
||||||
private static <V, S> List<S> mergeVectorValues(
|
private static <V, S> List<S> mergeVectorValues(
|
||||||
KnnVectorsReader[] knnVectorsReaders,
|
KnnVectorsReader[] knnVectorsReaders,
|
||||||
MergeState.DocMap[] docMaps,
|
MergeState.DocMap[] docMaps,
|
||||||
|
FieldInfo mergingField,
|
||||||
|
FieldInfos[] sourceFieldInfos,
|
||||||
IOFunction<KnnVectorsReader, V> valuesSupplier,
|
IOFunction<KnnVectorsReader, V> valuesSupplier,
|
||||||
BiFunction<MergeState.DocMap, V, S> newSub)
|
BiFunction<MergeState.DocMap, V, S> newSub)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
List<S> subs = new ArrayList<>();
|
List<S> subs = new ArrayList<>();
|
||||||
for (int i = 0; i < knnVectorsReaders.length; i++) {
|
for (int i = 0; i < knnVectorsReaders.length; i++) {
|
||||||
|
FieldInfos sourceFieldInfo = sourceFieldInfos[i];
|
||||||
|
if (hasVectorValues(sourceFieldInfo, mergingField.name) == false) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
|
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
|
||||||
if (knnVectorsReader != null) {
|
if (knnVectorsReader != null) {
|
||||||
V values = valuesSupplier.apply(knnVectorsReader);
|
V values = valuesSupplier.apply(knnVectorsReader);
|
||||||
|
@ -239,12 +261,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
mergeVectorValues(
|
mergeVectorValues(
|
||||||
mergeState.knnVectorsReaders,
|
mergeState.knnVectorsReaders,
|
||||||
mergeState.docMaps,
|
mergeState.docMaps,
|
||||||
knnVectorsReader -> {
|
fieldInfo,
|
||||||
return knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
mergeState.fieldInfos,
|
||||||
},
|
knnVectorsReader -> knnVectorsReader.getFloatVectorValues(fieldInfo.name),
|
||||||
(docMap, values) -> {
|
FloatVectorValuesSub::new),
|
||||||
return new FloatVectorValuesSub(docMap, values);
|
|
||||||
}),
|
|
||||||
mergeState);
|
mergeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,12 +276,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
mergeVectorValues(
|
mergeVectorValues(
|
||||||
mergeState.knnVectorsReaders,
|
mergeState.knnVectorsReaders,
|
||||||
mergeState.docMaps,
|
mergeState.docMaps,
|
||||||
knnVectorsReader -> {
|
fieldInfo,
|
||||||
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
|
mergeState.fieldInfos,
|
||||||
},
|
knnVectorsReader -> knnVectorsReader.getByteVectorValues(fieldInfo.name),
|
||||||
(docMap, values) -> {
|
ByteVectorValuesSub::new),
|
||||||
return new ByteVectorValuesSub(docMap, values);
|
|
||||||
}),
|
|
||||||
mergeState);
|
mergeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -174,6 +174,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
@ -197,6 +200,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"field=\""
|
"field=\""
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||||
|
|
||||||
|
@ -353,9 +354,11 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
: new TaskExecutor(mergeState.intraMergeTaskExecutor),
|
: new TaskExecutor(mergeState.intraMergeTaskExecutor),
|
||||||
numMergeWorkers);
|
numMergeWorkers);
|
||||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||||
|
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
|
||||||
merger.addReader(
|
merger.addReader(
|
||||||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
DocIdSetIterator mergedVectorIterator = null;
|
DocIdSetIterator mergedVectorIterator = null;
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
case BYTE ->
|
case BYTE ->
|
||||||
|
|
|
@ -165,8 +165,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
if (fieldEntry == null) {
|
||||||
return null;
|
throw new IllegalArgumentException("field=\"" + field + "\" not found");
|
||||||
|
}
|
||||||
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
}
|
}
|
||||||
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
||||||
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
|
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
||||||
|
@ -630,7 +631,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
||||||
IntArrayList segmentSizes = new IntArrayList(mergeState.liveDocs.length);
|
IntArrayList segmentSizes = new IntArrayList(mergeState.liveDocs.length);
|
||||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||||
FloatVectorValues fvv;
|
FloatVectorValues fvv;
|
||||||
if (mergeState.knnVectorsReaders[i] != null
|
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)
|
||||||
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
|
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
|
||||||
&& fvv.size() > 0) {
|
&& fvv.size() > 0) {
|
||||||
ScalarQuantizer quantizationState =
|
ScalarQuantizer quantizationState =
|
||||||
|
@ -928,8 +929,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
||||||
|
|
||||||
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
|
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
|
||||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||||
if (mergeState.knnVectorsReaders[i] != null
|
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
|
||||||
&& mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) {
|
|
||||||
QuantizedVectorsReader reader =
|
QuantizedVectorsReader reader =
|
||||||
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
|
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
|
||||||
assert scalarQuantizer != null;
|
assert scalarQuantizer != null;
|
||||||
|
|
|
@ -27,11 +27,14 @@ import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
|
import org.apache.lucene.codecs.FilterCodec;
|
||||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.KnnByteVectorField;
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
|
@ -53,6 +56,9 @@ import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.MergePolicy;
|
||||||
|
import org.apache.lucene.index.MergeScheduler;
|
||||||
|
import org.apache.lucene.index.MergeTrigger;
|
||||||
import org.apache.lucene.index.NoMergePolicy;
|
import org.apache.lucene.index.NoMergePolicy;
|
||||||
import org.apache.lucene.index.SegmentInfo;
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
@ -230,6 +236,106 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testMergingWithDifferentKnnFields() throws Exception {
|
||||||
|
try (var dir = newDirectory()) {
|
||||||
|
IndexWriterConfig iwc = new IndexWriterConfig();
|
||||||
|
Codec codec = getCodec();
|
||||||
|
if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perFieldKnnVectorsFormat) {
|
||||||
|
final KnnVectorsFormat format =
|
||||||
|
perFieldKnnVectorsFormat.getKnnVectorsFormatForField("field");
|
||||||
|
iwc.setCodec(
|
||||||
|
new FilterCodec(codec.getName(), codec) {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat knnVectorsFormat() {
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
TestMergeScheduler mergeScheduler = new TestMergeScheduler();
|
||||||
|
iwc.setMergeScheduler(mergeScheduler);
|
||||||
|
iwc.setMergePolicy(new ForceMergePolicy(iwc.getMergePolicy()));
|
||||||
|
try (var writer = new IndexWriter(dir, iwc)) {
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
var doc = new Document();
|
||||||
|
doc.add(new KnnFloatVectorField("field", new float[] {i, i + 1, i + 2, i + 3}));
|
||||||
|
writer.addDocument(doc);
|
||||||
|
}
|
||||||
|
writer.commit();
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
var doc = new Document();
|
||||||
|
doc.add(new KnnFloatVectorField("otherVector", new float[] {i, i, i, i}));
|
||||||
|
writer.addDocument(doc);
|
||||||
|
}
|
||||||
|
writer.commit();
|
||||||
|
writer.forceMerge(1);
|
||||||
|
assertNull(mergeScheduler.ex.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMergingWithDifferentByteKnnFields() throws Exception {
|
||||||
|
try (var dir = newDirectory()) {
|
||||||
|
IndexWriterConfig iwc = new IndexWriterConfig();
|
||||||
|
Codec codec = getCodec();
|
||||||
|
if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perFieldKnnVectorsFormat) {
|
||||||
|
final KnnVectorsFormat format =
|
||||||
|
perFieldKnnVectorsFormat.getKnnVectorsFormatForField("field");
|
||||||
|
iwc.setCodec(
|
||||||
|
new FilterCodec(codec.getName(), codec) {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat knnVectorsFormat() {
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
TestMergeScheduler mergeScheduler = new TestMergeScheduler();
|
||||||
|
iwc.setMergeScheduler(mergeScheduler);
|
||||||
|
iwc.setMergePolicy(new ForceMergePolicy(iwc.getMergePolicy()));
|
||||||
|
try (var writer = new IndexWriter(dir, iwc)) {
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
var doc = new Document();
|
||||||
|
doc.add(
|
||||||
|
new KnnByteVectorField("field", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i}));
|
||||||
|
writer.addDocument(doc);
|
||||||
|
}
|
||||||
|
writer.commit();
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
var doc = new Document();
|
||||||
|
doc.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"otherVector", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i}));
|
||||||
|
writer.addDocument(doc);
|
||||||
|
}
|
||||||
|
writer.commit();
|
||||||
|
writer.forceMerge(1);
|
||||||
|
assertNull(mergeScheduler.ex.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class TestMergeScheduler extends MergeScheduler {
|
||||||
|
AtomicReference<Exception> ex = new AtomicReference<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void merge(MergeSource mergeSource, MergeTrigger trigger) throws IOException {
|
||||||
|
while (true) {
|
||||||
|
MergePolicy.OneMerge merge = mergeSource.getNextMerge();
|
||||||
|
if (merge == null) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
mergeSource.merge(merge);
|
||||||
|
} catch (IllegalStateException | IllegalArgumentException e) {
|
||||||
|
ex.set(e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {}
|
||||||
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public void testWriterRamEstimate() throws Exception {
|
public void testWriterRamEstimate() throws Exception {
|
||||||
final FieldInfos fieldInfos = new FieldInfos(new FieldInfo[0]);
|
final FieldInfos fieldInfos = new FieldInfos(new FieldInfo[0]);
|
||||||
|
|
Loading…
Reference in New Issue