Fix vector type check for diversified knn search (#13235)

I repeatably saw some test failures related to `TestParentBlockJoin[Byte|Float]KnnVectorQuery#testVectorEncodingMismatch`. This commit fixes those test failures and actually checks the field type.
This commit is contained in:
Benjamin Trent 2024-03-29 13:49:58 -04:00 committed by GitHub
parent 6d120c49a4
commit c0d81932df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 24 deletions

View File

@ -136,6 +136,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
ByteVectorValues.checkField(context.reader(), field);
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;

View File

@ -136,6 +136,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
FloatVectorValues.checkField(context.reader(), field);
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;

View File

@ -17,16 +17,21 @@
package org.apache.lucene.search.join;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
@ -54,16 +59,25 @@ public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVec
}
public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenFloatKnnVectorQuery(
"field", new float[] {1, 2}, filter, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
try (Directory d = newDirectory()) {
try (IndexWriter w =
new IndexWriter(
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
List<Document> toAdd = new ArrayList<>();
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
toAdd.add(doc);
toAdd.add(makeParent(new int[] {1}));
w.addDocuments(toAdd);
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenFloatKnnVectorQuery(
"field", new float[] {1, 2}, null, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
}
}
}

View File

@ -29,11 +29,9 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
@ -50,16 +48,25 @@ public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVe
}
public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenByteKnnVectorQuery(
"field", new byte[] {1, 2}, filter, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
try (Directory d = newDirectory()) {
try (IndexWriter w =
new IndexWriter(
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
List<Document> toAdd = new ArrayList<>();
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
toAdd.add(doc);
toAdd.add(makeParent(new int[] {1}));
w.addDocuments(toAdd);
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenByteKnnVectorQuery(
"field", new byte[] {1, 2}, null, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
}
}
}