mirror of https://github.com/apache/lucene.git
Fix vector type check for diversified knn search (#13235)
I repeatedly saw some test failures related to `TestParentBlockJoin[Byte|Float]KnnVectorQuery#testVectorEncodingMismatch`. This commit fixes those test failures and actually checks the field type.
This commit is contained in:
parent
6d120c49a4
commit
c0d81932df
|
@ -136,6 +136,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
ByteVectorValues.checkField(context.reader(), field);
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
if (collector == null) {
|
||||
return NO_RESULTS;
|
||||
|
|
|
@ -136,6 +136,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
FloatVectorValues.checkField(context.reader(), field);
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
if (collector == null) {
|
||||
return NO_RESULTS;
|
||||
|
|
|
@ -17,16 +17,21 @@
|
|||
|
||||
package org.apache.lucene.search.join;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
|
||||
public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
|
||||
|
@ -54,18 +59,27 @@ public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVec
|
|||
}
|
||||
|
||||
public void testVectorEncodingMismatch() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (IndexWriter w =
|
||||
new IndexWriter(
|
||||
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
|
||||
List<Document> toAdd = new ArrayList<>();
|
||||
Document doc = new Document();
|
||||
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
|
||||
toAdd.add(doc);
|
||||
toAdd.add(makeParent(new int[] {1}));
|
||||
w.addDocuments(toAdd);
|
||||
}
|
||||
try (IndexReader reader = DirectoryReader.open(d)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
Query filter = new TermQuery(new Term("other", "value"));
|
||||
BitSetProducer parentFilter = parentFilter(reader);
|
||||
Query kvq =
|
||||
new DiversifyingChildrenFloatKnnVectorQuery(
|
||||
"field", new float[] {1, 2}, filter, 2, parentFilter);
|
||||
"field", new float[] {1, 2}, null, 2, parentFilter);
|
||||
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] fromFloat(float[] queryVector) {
|
||||
byte[] query = new byte[queryVector.length];
|
||||
|
|
|
@ -29,11 +29,9 @@ import org.apache.lucene.index.DirectoryReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
|
||||
public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
|
||||
|
@ -50,18 +48,27 @@ public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVe
|
|||
}
|
||||
|
||||
public void testVectorEncodingMismatch() throws IOException {
|
||||
try (Directory indexStore =
|
||||
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (IndexWriter w =
|
||||
new IndexWriter(
|
||||
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
|
||||
List<Document> toAdd = new ArrayList<>();
|
||||
Document doc = new Document();
|
||||
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
|
||||
toAdd.add(doc);
|
||||
toAdd.add(makeParent(new int[] {1}));
|
||||
w.addDocuments(toAdd);
|
||||
}
|
||||
try (IndexReader reader = DirectoryReader.open(d)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
Query filter = new TermQuery(new Term("other", "value"));
|
||||
BitSetProducer parentFilter = parentFilter(reader);
|
||||
Query kvq =
|
||||
new DiversifyingChildrenByteKnnVectorQuery(
|
||||
"field", new byte[] {1, 2}, filter, 2, parentFilter);
|
||||
"field", new byte[] {1, 2}, null, 2, parentFilter);
|
||||
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testScoreCosine() throws IOException {
|
||||
try (Directory d = newDirectory()) {
|
||||
|
|
Loading…
Reference in New Issue